diff --git a/.clang-tidy b/.clang-tidy index 310c3d182..5bc63bc6e 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -13,6 +13,7 @@ Checks: > -readability-magic-numbers, -readability-uppercase-literal-suffix, -readability-simplify-boolean-expr, + -readability-math-missing-parentheses, clang-analyzer-*, -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling, performance-*, diff --git a/.devops/cpu.Dockerfile b/.devops/cpu.Dockerfile index aa2aa0312..9459f08c1 100644 --- a/.devops/cpu.Dockerfile +++ b/.devops/cpu.Dockerfile @@ -14,9 +14,9 @@ WORKDIR /app COPY . . RUN if [ "$TARGETARCH" = "amd64" ]; then \ - cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \ elif [ "$TARGETARCH" = "arm64" ]; then \ - cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \ + cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_CPU_ARM_ARCH=${GGML_CPU_ARM_ARCH}; \ else \ echo "Unsupported architecture"; \ exit 1; \ diff --git a/.devops/cuda.Dockerfile b/.devops/cuda.Dockerfile index 8ae57d2e2..94f143397 100644 --- a/.devops/cuda.Dockerfile +++ b/.devops/cuda.Dockerfile @@ -21,7 +21,7 @@ COPY . . RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ fi && \ - cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_CUDA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile index 091e1dc5d..9ce80a71e 100644 --- a/.devops/intel.Dockerfile +++ b/.devops/intel.Dockerfile @@ -1,4 +1,4 @@ -ARG ONEAPI_VERSION=2025.0.0-0-devel-ubuntu22.04 +ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04 ## Build Image @@ -17,7 +17,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ fi && \ echo "Building with dynamic libs" && \ - cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${OPT_SYCL_F16} && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${OPT_SYCL_F16} && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ @@ -49,19 +49,23 @@ COPY --from=build /app/full /app WORKDIR /app -RUN apt-get update \ - && apt-get install -y \ - git \ - python3 \ - python3-pip \ - && pip install --upgrade pip setuptools wheel \ - && pip install -r requirements.txt \ - && apt autoremove -y \ - && apt clean -y \ - && rm -rf /tmp/* /var/tmp/* \ - && find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete \ - && find /var/cache -type f -delete +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python3-pip \ + python3-venv && \ + python3 -m venv /opt/venv && \ + . /opt/venv/bin/activate && \ + pip install --upgrade pip setuptools wheel && \ + pip install -r requirements.txt && \ + apt autoremove -y && \ + apt clean -y && \ + rm -rf /tmp/* /var/tmp/* && \ + find /var/cache/apt/archives /var/lib/apt/lists -not -name lock -type f -delete && \ + find /var/cache -type f -delete +ENV PATH="/opt/venv/bin:$PATH" ENTRYPOINT ["/app/tools.sh"] diff --git a/.devops/llama-cli-cann.Dockerfile b/.devops/llama-cli-cann.Dockerfile index 0eb1af87c..ef43d78cd 100644 --- a/.devops/llama-cli-cann.Dockerfile +++ b/.devops/llama-cli-cann.Dockerfile @@ -22,7 +22,7 @@ ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/runtime/lib64/stub:$LD_LIBRARY_PATH RUN echo "Building with static libs" && \ source /usr/local/Ascend/ascend-toolkit/set_env.sh --force && \ - cmake -B build -DGGML_NATIVE=OFF -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_CANN=ON -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_TESTS=OFF && \ cmake --build build --config Release --target llama-cli # TODO: use image with NNRT diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile index 261a2823a..87ce2393f 100644 --- a/.devops/musa.Dockerfile +++ b/.devops/musa.Dockerfile @@ -1,10 +1,10 @@ ARG UBUNTU_VERSION=22.04 # This needs to generally match the container host's environment. -ARG MUSA_VERSION=rc3.1.1 +ARG MUSA_VERSION=rc4.0.1 # Target the MUSA build image -ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-devel-ubuntu${UBUNTU_VERSION} +ARG BASE_MUSA_DEV_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-devel-ubuntu${UBUNTU_VERSION} -ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} +ARG BASE_MUSA_RUN_CONTAINER=mthreads/musa:${MUSA_VERSION}-mudnn-runtime-ubuntu${UBUNTU_VERSION} FROM ${BASE_MUSA_DEV_CONTAINER} AS build @@ -21,21 +21,14 @@ RUN apt-get update && \ libcurl4-openssl-dev \ libgomp1 -COPY requirements.txt requirements.txt -COPY requirements requirements - -RUN pip install --upgrade pip setuptools wheel \ - && pip install -r requirements.txt - WORKDIR /app COPY . . -# Use the default MUSA archs if not specified RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \ fi && \ - cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DLLAMA_CURL=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ + cmake -B build -DGGML_NATIVE=OFF -DGGML_MUSA=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DLLAMA_BUILD_TESTS=OFF ${CMAKE_ARGS} -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined . && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index a1b34723a..1c00f1b9c 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -40,7 +40,7 @@ WORKDIR /app COPY . . RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ - cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_CURL=ON \ + cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=$ROCM_DOCKER_ARCH -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DCMAKE_BUILD_TYPE=Release -DLLAMA_BUILD_TESTS=OFF \ && cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib \ diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index f8f3072e9..fcd81ffa1 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -16,7 +16,7 @@ WORKDIR /app COPY . . -RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_CURL=1 -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \ +RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=1 -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \ cmake --build build --config Release -j$(nproc) RUN mkdir -p /app/lib && \ diff --git a/.editorconfig b/.editorconfig index 5d63d0a51..c90b171f5 100644 --- a/.editorconfig +++ b/.editorconfig @@ -21,15 +21,15 @@ indent_style = tab [prompts/*.txt] insert_final_newline = unset -[examples/server/public/*] +[tools/server/public/*] indent_size = 2 -[examples/server/public/deps_*] +[tools/server/public/deps_*] trim_trailing_whitespace = unset indent_style = unset indent_size = unset -[examples/server/deps_*] +[tools/server/deps_*] trim_trailing_whitespace = unset indent_style = unset indent_size = unset @@ -37,7 +37,7 @@ indent_size = unset [examples/llama.swiftui/llama.swiftui.xcodeproj/*] indent_style = tab -[examples/cvector-generator/*.txt] +[tools/cvector-generator/*.txt] trim_trailing_whitespace = unset insert_final_newline = unset @@ -48,3 +48,7 @@ end_of_line = unset charset = unset trim_trailing_whitespace = unset insert_final_newline = unset + +[vendor/miniaudio/miniaudio.h] +trim_trailing_whitespace = unset +insert_final_newline = unset diff --git a/.flake8 b/.flake8 index d64c2564a..669d231f1 100644 --- a/.flake8 +++ b/.flake8 @@ -2,8 +2,9 @@ max-line-length = 125 ignore = E203,E211,E221,E225,E231,E241,E251,E261,E266,E501,E701,E704,W503 exclude = - # Do not traverse examples + # Do not traverse examples and tools examples, + tools, # Do not include package initializers __init__.py, # No need to traverse our git directory diff --git a/.github/actions/get-tag-name/action.yml b/.github/actions/get-tag-name/action.yml new file mode 100644 index 000000000..7ace23b2a --- /dev/null +++ b/.github/actions/get-tag-name/action.yml @@ -0,0 +1,22 @@ +name: "Determine tag name" +description: "Determine the tag name to use for a release" +outputs: + name: + description: "The name of the tag" + value: ${{ steps.tag.outputs.name }} + +runs: + using: "composite" + steps: + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER="$(git rev-list --count HEAD)" + SHORT_HASH="$(git rev-parse --short=7 HEAD)" + if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT + fi diff --git a/.github/actions/windows-setup-cuda/action.yml b/.github/actions/windows-setup-cuda/action.yml new file mode 100644 index 000000000..5575caeca --- /dev/null +++ b/.github/actions/windows-setup-cuda/action.yml @@ -0,0 +1,67 @@ +name: "Windows - Setup CUDA Toolkit" +description: "Setup CUDA Toolkit for Windows" +inputs: + cuda_version: + description: "CUDA toolkit version" + required: true + +runs: + using: "composite" + steps: + - name: Install Cuda Toolkit 11.7 + if: ${{ inputs.cuda_version == '11.7' }} + shell: pwsh + run: | + mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" + choco install unzip -y + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-11.7.99-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-11.7.4.6-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-11.7.91-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-11.7.91-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-11.7.101-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-11.7.91-archive.zip" + unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cudart-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvcc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvrtc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libcublas-windows-x86_64-11.7.4.6-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvtx-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\visual_studio_integration-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvprof-windows-x86_64-11.7.101-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cccl-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V11_7=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Install Cuda Toolkit 12.4 + if: ${{ inputs.cuda_version == '12.4' }} + shell: pwsh + run: | + mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" + choco install unzip -y + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip" + curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip" + unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cudart-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvcc-windows-x86_64-12.4.131-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvrtc-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libcublas-windows-x86_64-12.4.5.8-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvtx-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_profiler_api-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\visual_studio_integration-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvprof-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cccl-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 diff --git a/.github/actions/windows-setup-curl/action.yml b/.github/actions/windows-setup-curl/action.yml index 5d76da3d7..446f799fa 100644 --- a/.github/actions/windows-setup-curl/action.yml +++ b/.github/actions/windows-setup-curl/action.yml @@ -5,6 +5,10 @@ inputs: description: 'CURL version' required: false default: '8.6.0_6' + architecture: + description: 'Architecture of the libcurl to download' + required: false + default: 'win64' outputs: curl_path: description: "Path to the downloaded libcurl" @@ -18,8 +22,9 @@ runs: shell: powershell env: CURL_VERSION: ${{ inputs.curl_version }} + ARCHITECTURE: ${{ inputs.architecture }} run: | - curl.exe -o $env:RUNNER_TEMP/curl.zip -L "https://curl.se/windows/dl-${env:CURL_VERSION}/curl-${env:CURL_VERSION}-win64-mingw.zip" + curl.exe -o $env:RUNNER_TEMP/curl.zip -L "https://curl.se/windows/dl-${env:CURL_VERSION}/curl-${env:CURL_VERSION}-${env:ARCHITECTURE}-mingw.zip" mkdir $env:RUNNER_TEMP/libcurl tar.exe -xvf $env:RUNNER_TEMP/curl.zip --strip-components=1 -C $env:RUNNER_TEMP/libcurl echo "curl_path=$env:RUNNER_TEMP/libcurl" >> $env:GITHUB_OUTPUT diff --git a/.github/labeler.yml b/.github/labeler.yml index 1b47bc968..3c2f67707 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -45,7 +45,9 @@ build: - CMakePresets.json examples: - changed-files: - - any-glob-to-any-file: examples/** + - any-glob-to-any-file: + - examples/** + - tools/** devops: - changed-files: - any-glob-to-any-file: @@ -70,7 +72,7 @@ android: server: - changed-files: - any-glob-to-any-file: - - examples/server/** + - tools/server/** ggml: - changed-files: - any-glob-to-any-file: @@ -84,3 +86,10 @@ nix: embedding: - changed-files: - any-glob-to-any-file: examples/embedding/ + +Ascend NPU: + - changed-files: + - any-glob-to-any-file: + - ggml/include/ggml-cann.h + - ggml/src/ggml-cann/** + - docs/backend/CANN.md diff --git a/.github/workflows/bench.yml.disabled b/.github/workflows/bench.yml.disabled index 75d271479..f2d7e16e9 100644 --- a/.github/workflows/bench.yml.disabled +++ b/.github/workflows/bench.yml.disabled @@ -27,10 +27,10 @@ on: push: branches: - master - paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'examples/server/*.h*', 'examples/server/*.cpp'] + paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'tools/server/*.h*', 'tools/server/*.cpp'] pull_request_target: types: [opened, synchronize, reopened] - paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'examples/server/*.h*', 'examples/server/*.cpp'] + paths: ['llama.cpp', 'ggml.c', 'ggml-backend.cpp', 'ggml-quants.c', '**/*.cu', 'tools/server/*.h*', 'tools/server/*.cpp'] schedule: - cron: '04 2 * * *' @@ -69,7 +69,7 @@ jobs: - name: Install python env id: pipenv run: | - cd examples/server/bench + cd tools/server/bench python3 -m venv venv source venv/bin/activate pip install -r requirements.txt @@ -79,7 +79,7 @@ jobs: run: | wget --quiet https://github.com/prometheus/prometheus/releases/download/v2.51.0/prometheus-2.51.0.linux-amd64.tar.gz tar xzf prometheus*.tar.gz --strip-components=1 - ./prometheus --config.file=examples/server/bench/prometheus.yml & + ./prometheus --config.file=tools/server/bench/prometheus.yml & while ! nc -z localhost 9090; do sleep 0.1 done @@ -92,7 +92,7 @@ jobs: - name: Install k6 and xk6-sse id: k6_installation run: | - cd examples/server/bench + cd tools/server/bench go install go.k6.io/xk6/cmd/xk6@latest xk6 build master \ --with github.com/phymbert/xk6-sse @@ -116,7 +116,7 @@ jobs: - name: Download the dataset id: download_dataset run: | - cd examples/server/bench + cd tools/server/bench wget --quiet https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - name: Server bench @@ -126,7 +126,7 @@ jobs: run: | set -eux - cd examples/server/bench + cd tools/server/bench source venv/bin/activate python bench.py \ --runner-label ${{ env.RUNNER_LABEL }} \ @@ -157,9 +157,9 @@ jobs: name: bench-server-${{ github.job }}-${{ env.RUNNER_LABEL }}-${{ matrix.model }}-${{ matrix.ftype }} compression-level: 9 path: | - examples/server/bench/*.jpg - examples/server/bench/*.json - examples/server/bench/*.log + tools/server/bench/*.jpg + tools/server/bench/*.json + tools/server/bench/*.log - name: Commit status uses: Sibz/github-status-action@v1 @@ -178,17 +178,17 @@ jobs: with: client_id: ${{secrets.IMGUR_CLIENT_ID}} path: | - examples/server/bench/prompt_tokens_seconds.jpg - examples/server/bench/predicted_tokens_seconds.jpg - examples/server/bench/kv_cache_usage_ratio.jpg - examples/server/bench/requests_processing.jpg + tools/server/bench/prompt_tokens_seconds.jpg + tools/server/bench/predicted_tokens_seconds.jpg + tools/server/bench/kv_cache_usage_ratio.jpg + tools/server/bench/requests_processing.jpg - name: Extract mermaid id: set_mermaid run: | set -eux - cd examples/server/bench + cd tools/server/bench PROMPT_TOKENS_SECONDS=$(cat prompt_tokens_seconds.mermaid) echo "PROMPT_TOKENS_SECONDS<> $GITHUB_ENV echo "$PROMPT_TOKENS_SECONDS" >> $GITHUB_ENV diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index e8639913e..7cfc82ba4 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -4,29 +4,37 @@ on: workflow_call: jobs: - ubuntu-latest-riscv64-cpu-cross: - runs-on: ubuntu-latest + ubuntu-24-riscv64-cpu-cross: + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - name: Setup Riscv run: | sudo dpkg --add-architecture riscv64 - sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \ - /etc/apt/sources.list /etc/apt/apt-mirrors.txt - sudo apt-get clean - sudo apt-get update + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + sudo apt-get install -y --no-install-recommends \ build-essential \ gcc-14-riscv64-linux-gnu \ - g++-14-riscv64-linux-gnu \ - libcurl4-openssl-dev:riscv64 + g++-14-riscv64-linux-gnu - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ -DLLAMA_BUILD_TESTS=OFF \ -DCMAKE_SYSTEM_NAME=Linux \ -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ @@ -40,35 +48,40 @@ jobs: cmake --build build --config Release -j $(nproc) - ubuntu-latest-riscv64-vulkan-cross: - runs-on: ubuntu-latest + ubuntu-24-riscv64-vulkan-cross: + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Setup Riscv run: | sudo dpkg --add-architecture riscv64 - sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \ - /etc/apt/sources.list /etc/apt/apt-mirrors.txt - sudo apt-get clean - sudo apt-get update + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/riscv64-ports.list + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=riscv64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + sudo apt-get install -y --no-install-recommends \ build-essential \ glslc \ gcc-14-riscv64-linux-gnu \ g++-14-riscv64-linux-gnu \ - libvulkan-dev:riscv64 \ - libcurl4-openssl-dev:riscv64 + libvulkan-dev:riscv64 - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_VULKAN=ON \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ -DLLAMA_BUILD_TESTS=OFF \ -DCMAKE_SYSTEM_NAME=Linux \ -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ @@ -82,34 +95,39 @@ jobs: cmake --build build --config Release -j $(nproc) - ubuntu-latest-arm64-vulkan-cross: - runs-on: ubuntu-latest + ubuntu-24-arm64-vulkan-cross: + runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: Setup Arm64 run: | sudo dpkg --add-architecture arm64 - sudo sed -i 's|http://azure.archive.ubuntu.com/ubuntu|http://ports.ubuntu.com/ubuntu-ports|g' \ - /etc/apt/sources.list /etc/apt/apt-mirrors.txt - sudo apt-get clean - sudo apt-get update + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/arm64-ports.list + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=arm64] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + sudo apt-get install -y --no-install-recommends \ build-essential \ glslc \ crossbuild-essential-arm64 \ - libvulkan-dev:arm64 \ - libcurl4-openssl-dev:arm64 + libvulkan-dev:arm64 - name: Build run: | - cmake -B build -DCMAKE_BUILD_TYPE=Release \ + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ -DGGML_VULKAN=ON \ -DGGML_OPENMP=OFF \ -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ -DLLAMA_BUILD_TESTS=OFF \ -DCMAKE_SYSTEM_NAME=Linux \ -DCMAKE_SYSTEM_PROCESSOR=aarch64 \ @@ -122,3 +140,207 @@ jobs: -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH cmake --build build --config Release -j $(nproc) + + ubuntu-24-ppc64el-cpu-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup PowerPC64le + run: | + sudo dpkg --add-architecture ppc64el + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + + sudo apt-get install -y --no-install-recommends \ + build-essential \ + gcc-14-powerpc64le-linux-gnu \ + g++-14-powerpc64le-linux-gnu + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ + -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + ubuntu-24-ppc64el-vulkan-cross: + runs-on: ubuntu-24.04 + + steps: + - uses: actions/checkout@v4 + - name: Setup PowerPC64le + run: | + sudo dpkg --add-architecture ppc64el + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | sudo tee /etc/apt/sources.list.d/ppc64el-ports.list + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-updates main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-security main universe + deb [arch=ppc64el] http://ports.ubuntu.com/ubuntu-ports/ noble-backports main universe + EOF + + sudo apt-get update || true ;# Prevent failure due to missing URLs. + + sudo apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + gcc-14-powerpc64le-linux-gnu \ + g++-14-powerpc64le-linux-gnu \ + libvulkan-dev:ppc64el + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=ppc64 \ + -DCMAKE_C_COMPILER=powerpc64le-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=powerpc64le-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/powerpc64le-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + debian-13-loongarch64-cpu-cross: + runs-on: ubuntu-24.04 + container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 + + steps: + - uses: actions/checkout@v4 + - name: Setup LoongArch + run: | + rm -f /etc/apt/sources.list.d/* + cat << EOF | tee /etc/apt/sources.list.d/debian-ports.list + deb http://snapshot.debian.org/archive/debian/20250515T202920Z/ trixie main + EOF + ( echo 'quiet "true";'; \ + echo 'APT::Get::Assume-Yes "true";'; \ + echo 'APT::Install-Recommends "false";'; \ + echo 'Acquire::Check-Valid-Until "false";'; \ + echo 'Acquire::Retries "5";'; \ + ) > /etc/apt/apt.conf.d/99snapshot-repos + + apt-get update + apt-get install -y ca-certificates debian-ports-archive-keyring cmake git zip + dpkg --add-architecture loong64 + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | tee /etc/apt/sources.list.d/loong64-ports.list + deb [arch=loong64] http://snapshot.debian.org/archive/debian-ports/20250515T194251Z/ sid main + EOF + + apt-get update || true ;# Prevent failure due to missing URLs. + + apt-get install -y --no-install-recommends \ + build-essential \ + gcc-14-loongarch64-linux-gnu \ + g++-14-loongarch64-linux-gnu + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=loongarch64 \ + -DCMAKE_C_COMPILER=loongarch64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=loongarch64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/loongarch64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) + + debian-13-loongarch64-vulkan-cross: + runs-on: ubuntu-24.04 + container: debian@sha256:653dfb9f86c3782e8369d5f7d29bb8faba1f4bff9025db46e807fa4c22903671 + + steps: + - uses: actions/checkout@v4 + - name: Setup LoongArch + run: | + rm -f /etc/apt/sources.list.d/* + cat << EOF | tee /etc/apt/sources.list.d/debian-ports.list + deb http://snapshot.debian.org/archive/debian/20250515T202920Z/ trixie main + EOF + ( echo 'quiet "true";'; \ + echo 'APT::Get::Assume-Yes "true";'; \ + echo 'APT::Install-Recommends "false";'; \ + echo 'Acquire::Check-Valid-Until "false";'; \ + echo 'Acquire::Retries "5";'; \ + ) > /etc/apt/apt.conf.d/99snapshot-repos + + apt-get update + apt-get install -y ca-certificates debian-ports-archive-keyring cmake git zip + dpkg --add-architecture loong64 + + # Add arch-specific repositories for non-amd64 architectures + cat << EOF | tee /etc/apt/sources.list.d/loong64-ports.list + deb [arch=loong64] http://snapshot.debian.org/archive/debian-ports/20250515T194251Z/ sid main + EOF + + apt-get update || true ;# Prevent failure due to missing URLs. + + apt-get install -y --no-install-recommends \ + build-essential \ + glslc \ + gcc-14-loongarch64-linux-gnu \ + g++-14-loongarch64-linux-gnu \ + libvulkan-dev:loong64 + + - name: Build + run: | + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_VULKAN=ON \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=loongarch64 \ + -DCMAKE_C_COMPILER=loongarch64-linux-gnu-gcc-14 \ + -DCMAKE_CXX_COMPILER=loongarch64-linux-gnu-g++-14 \ + -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + -DCMAKE_FIND_ROOT_PATH=/usr/lib/loongarch64-linux-gnu \ + -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH + + cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index bcfcf08ac..c4783a6df 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -2,30 +2,19 @@ name: CI on: workflow_dispatch: # allows manual triggering - inputs: - create_release: - description: 'Create new release' - required: true - type: boolean push: branches: - master - paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] + paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] pull_request: types: [opened, synchronize, reopened] - paths: ['.github/workflows/build.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] + paths: ['.github/workflows/build.yml', '.github/workflows/build-linux-cross.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] concurrency: group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} cancel-in-progress: true -# Fine-grant permission -# https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token -permissions: - contents: write # for creating release - env: - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} GGML_NLOOP: 3 GGML_N_THREADS: 1 LLAMA_LOG_COLORS: 1 @@ -40,8 +29,6 @@ jobs: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -74,33 +61,6 @@ jobs: cd build ctest -L 'main|curl' --verbose --timeout 900 - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip - name: llama-bin-macos-arm64.zip - macOS-latest-cmake-x64: runs-on: macos-13 @@ -108,8 +68,6 @@ jobs: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -143,33 +101,6 @@ jobs: cd build ctest -L main --verbose --timeout 900 - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip - name: llama-bin-macos-x64.zip - ubuntu-cpu-cmake: strategy: matrix: @@ -185,8 +116,6 @@ jobs: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -225,33 +154,6 @@ jobs: ./bin/llama-convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf ./bin/llama-cli -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip - name: llama-bin-ubuntu-${{ matrix.build }}.zip - ubuntu-latest-cmake-sanitizer: runs-on: ubuntu-latest @@ -378,8 +280,6 @@ jobs: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -406,35 +306,9 @@ jobs: id: cmake_test run: | cd build + export GGML_VK_VISIBLE_DEVICES=0 # This is using llvmpipe and runs slower than other backends - ctest -L main --verbose --timeout 2700 - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - cp LICENSE ./build/bin/ - zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip - name: llama-bin-ubuntu-vulkan-x64.zip + ctest -L main --verbose --timeout 3600 ubuntu-22-cmake-hip: runs-on: ubuntu-22.04 @@ -478,7 +352,7 @@ jobs: ubuntu-22-cmake-musa: runs-on: ubuntu-22.04 - container: mthreads/musa:rc3.1.1-devel-ubuntu22.04 + container: mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04 steps: - name: Clone @@ -633,6 +507,7 @@ jobs: -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_COMMON=OFF \ -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ -DLLAMA_BUILD_TESTS=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -DCMAKE_SYSTEM_NAME=iOS \ @@ -669,6 +544,7 @@ jobs: -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_COMMON=OFF \ -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ -DLLAMA_BUILD_TESTS=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -DCMAKE_SYSTEM_NAME=tvOS \ @@ -699,6 +575,7 @@ jobs: -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_BUILD_COMMON=OFF \ -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ -DLLAMA_BUILD_TESTS=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -DCMAKE_SYSTEM_NAME=visionOS \ @@ -739,6 +616,7 @@ jobs: -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_CURL=OFF \ -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ -DLLAMA_BUILD_TESTS=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" @@ -767,7 +645,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2.16 with: key: windows-msys2 - variant: sccache + variant: ccache evict-old-files: 1d - name: Setup ${{ matrix.sys }} @@ -810,39 +688,29 @@ jobs: strategy: matrix: include: - - build: 'noavx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF' - - build: 'avx2-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON' - - build: 'avx-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX2=OFF' - - build: 'avx512-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_AVX512=ON' + - build: 'cpu-x64 (static)' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DBUILD_SHARED_LIBS=OFF' - build: 'openblas-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' - - build: 'kompute-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS -DBLAS_INCLUDE_DIRS="$env:RUNNER_TEMP/openblas/include" -DBLAS_LIBRARIES="$env:RUNNER_TEMP/openblas/lib/openblas.lib"' - build: 'vulkan-x64' - defines: '-DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_VULKAN=ON' + defines: '-DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_VULKAN=ON' - build: 'llvm-arm64' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' - - build: 'msvc-arm64' - defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-msvc.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON' - build: 'llvm-arm64-opencl-adreno' defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' + # - build: 'kompute-x64' + # defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/x64-windows-llvm.cmake -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_OPENMP=OFF -DGGML_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON' steps: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 with: key: windows-latest-cmake-${{ matrix.build }} - variant: sccache + variant: ccache evict-old-files: 1d - name: Clone Kompute submodule @@ -910,6 +778,7 @@ jobs: cmake -S . -B build ${{ matrix.defines }} ` -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" cmake --build build --config Release -j ${env:NUMBER_OF_PROCESSORS} + cp $env:CURL_PATH/bin/libcurl-*.dll build/bin/Release - name: Add libopenblas.dll id: add_libopenblas_dll @@ -918,68 +787,26 @@ jobs: cp $env:RUNNER_TEMP/openblas/bin/libopenblas.dll ./build/bin/Release/openblas.dll cp $env:RUNNER_TEMP/OpenBLAS.LICENSE.txt ./build/bin/Release/OpenBLAS-${env:OPENBLAS_VERSION}.txt - - name: Check AVX512F support - id: check_avx512f - if: ${{ matrix.build == 'avx512-x64' }} - continue-on-error: true - run: | - cd build - $vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath) - $msvc = $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim())) - $cl = $(join-path $msvc 'bin\Hostx64\x64\cl.exe') - echo 'int main(void){unsigned int a[4];__cpuid(a,7);return !(a[1]&65536);}' >> avx512f.c - & $cl /O2 /GS- /kernel avx512f.c /link /nodefaultlib /entry:main - .\avx512f.exe && echo "AVX512F: YES" && ( echo HAS_AVX512F=1 >> $env:GITHUB_ENV ) || echo "AVX512F: NO" - - name: Test id: cmake_test - # not all machines have native AVX-512 - if: ${{ matrix.build != 'msvc-arm64' && matrix.build != 'llvm-arm64' && matrix.build != 'llvm-arm64-opencl-adreno' && matrix.build != 'kompute-x64' && matrix.build != 'vulkan-x64' && (matrix.build != 'avx512-x64' || env.HAS_AVX512F == '1') }} + if: ${{ matrix.build != 'llvm-arm64' && matrix.build != 'llvm-arm64-opencl-adreno' }} run: | cd build ctest -L main -C Release --verbose --timeout 900 - - name: Test (Intel SDE) - id: cmake_test_sde - if: ${{ matrix.build == 'avx512-x64' && env.HAS_AVX512F == '0' }} # use Intel SDE for AVX-512 emulation - run: | - curl.exe -o $env:RUNNER_TEMP/sde.tar.xz -L "https://downloadmirror.intel.com/813591/sde-external-${env:SDE_VERSION}-win.tar.xz" - # for some weird reason windows tar doesn't like sde tar.xz - 7z x "-o${env:RUNNER_TEMP}" $env:RUNNER_TEMP/sde.tar.xz - 7z x "-o${env:RUNNER_TEMP}" $env:RUNNER_TEMP/sde.tar - $sde = $(join-path $env:RUNNER_TEMP sde-external-${env:SDE_VERSION}-win/sde.exe) - cd build - $env:LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR = 1 - & $sde -future -- ctest -L main -C Release --verbose --timeout 900 - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} - run: | - Copy-Item $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\Release\libcurl-x64.dll - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip .\build\bin\Release\* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}.zip - name: llama-bin-win-${{ matrix.build }}.zip + # TODO: disabled for now, consider adding tests for all CPU variants instead + # - name: Test (Intel SDE) + # id: cmake_test_sde + # if: ${{ matrix.build == 'avx512-x64' && env.HAS_AVX512F == '0' }} # use Intel SDE for AVX-512 emulation + # run: | + # curl.exe -o $env:RUNNER_TEMP/sde.tar.xz -L "https://downloadmirror.intel.com/813591/sde-external-${env:SDE_VERSION}-win.tar.xz" + # # for some weird reason windows tar doesn't like sde tar.xz + # 7z x "-o${env:RUNNER_TEMP}" $env:RUNNER_TEMP/sde.tar.xz + # 7z x "-o${env:RUNNER_TEMP}" $env:RUNNER_TEMP/sde.tar + # $sde = $(join-path $env:RUNNER_TEMP sde-external-${env:SDE_VERSION}-win/sde.exe) + # cd build + # $env:LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR = 1 + # & $sde -future -- ctest -L main -C Release --verbose --timeout 900 ubuntu-latest-cmake-cuda: runs-on: ubuntu-latest @@ -989,8 +816,6 @@ jobs: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: Install dependencies env: @@ -1016,83 +841,29 @@ jobs: -DGGML_CUDA=ON cmake --build build - windows-2019-cmake-cuda: - runs-on: windows-2019 + windows-2022-cmake-cuda: + runs-on: windows-2022 strategy: matrix: - cuda: ['12.4', '11.7'] - build: ['cuda'] + cuda: ['12.4'] steps: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: Install ccache uses: hendrikmuhs/ccache-action@v1.2.16 with: - key: ${{ github.job }}-${{ matrix.cuda }}-${{ matrix.build }} - variant: sccache + key: windows-cuda-${{ matrix.cuda }} + variant: ccache evict-old-files: 1d - - name: Install Cuda Toolkit 11.7 - if: ${{ matrix.cuda == '11.7' }} - run: | - mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" - choco install unzip -y - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-11.7.99-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-11.7.99-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-11.7.99-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-11.7.4.6-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-11.7.91-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-11.7.91-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-11.7.101-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-11.7.91-archive.zip" - unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cudart-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvcc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvrtc-windows-x86_64-11.7.99-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libcublas-windows-x86_64-11.7.4.6-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvtx-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\visual_studio_integration-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_nvprof-windows-x86_64-11.7.101-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\cuda_cccl-windows-x86_64-11.7.91-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" /E /I /H /Y - echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V11_7=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Install Cuda Toolkit 12.4 - if: ${{ matrix.cuda == '12.4' }} - run: | - mkdir -p "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" - choco install unzip -y - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-12.4.131-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/libcublas/windows-x86_64/libcublas-windows-x86_64-12.4.5.8-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-12.4.127-archive.zip" - curl -O "https://developer.download.nvidia.com/compute/cuda/redist/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-12.4.127-archive.zip" - unzip '*.zip' -d "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cudart-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvcc-windows-x86_64-12.4.131-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvrtc-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libcublas-windows-x86_64-12.4.5.8-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvtx-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_profiler_api-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\visual_studio_integration-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_nvprof-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - xcopy "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\cuda_cccl-windows-x86_64-12.4.127-archive\*" "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" /E /I /H /Y - echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V12_4=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.4" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + - name: Install Cuda Toolkit + uses: ./.github/actions/windows-setup-cuda + with: + cuda_version: ${{ matrix.cuda }} - name: Install Ninja id: install_ninja @@ -1109,10 +880,12 @@ jobs: env: CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} run: | - call "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 cmake -S . -B build -G "Ninja Multi-Config" ^ -DLLAMA_BUILD_SERVER=ON ^ -DGGML_NATIVE=OFF ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_CPU_ALL_VARIANTS=ON ^ -DGGML_CUDA=ON ^ -DGGML_RPC=ON ^ -DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" @@ -1120,51 +893,6 @@ jobs: cmake --build build --config Release -j %NINJA_JOBS% -t ggml cmake --build build --config Release - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} - run: | - cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\Release\libcurl-x64.dll - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}-cu${{ matrix.cuda }}-x64.zip .\build\bin\Release\* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-${{ matrix.build }}-cu${{ matrix.cuda }}-x64.zip - name: llama-bin-win-cu${{ matrix.cuda }}-x64.zip - - - name: Copy and pack Cuda runtime - if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - run: | - echo "Cuda install location: ${{ env.CUDA_PATH }}" - $dst='.\build\bin\cudart\' - robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll - robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll - 7z a cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip $dst\* - - - name: Upload Cuda runtime - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip - name: cudart-llama-bin-win-cu${{ matrix.cuda }}-x64.zip - windows-latest-cmake-sycl: runs-on: windows-latest @@ -1173,21 +901,19 @@ jobs: shell: bash env: - WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" steps: - name: Clone id: checkout uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 with: key: windows-latest-cmake-sycl - variant: sccache + variant: ccache evict-old-files: 1d - name: Install @@ -1200,52 +926,6 @@ jobs: id: cmake_build run: examples/sycl/win-build-sycl.bat - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Build the release package - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - echo "cp oneAPI running time dll files in ${{ env.ONEAPI_ROOT }} to ./build/bin" - - cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_sycl_blas.5.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_core.2.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_tbb_thread.2.dll" ./build/bin - - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_opencl.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_loader.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_win_proxy_loader.dll" ./build/bin - - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl8.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin - - cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin - cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin - - echo "cp oneAPI running time dll files to ./build/bin done" - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip ./build/bin/* - - - name: Upload the release package - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-sycl-x64.zip - name: llama-bin-win-sycl-x64.zip - windows-latest-cmake-hip: if: ${{ github.event.inputs.create_release != 'true' }} runs-on: windows-latest @@ -1303,110 +983,12 @@ jobs: -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" cmake --build build -j ${env:NUMBER_OF_PROCESSORS} - # TODO: reuse windows-latest-cmake-hip instead of duplicating this job - windows-latest-cmake-hip-release: - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - runs-on: windows-latest - - strategy: - matrix: - gpu_target: [gfx1100, gfx1101, gfx1030] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Clone rocWMMA repository - id: clone_rocwmma - run: | - git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: windows-latest-cmake-hip-release - evict-old-files: 1d - - - name: Install - id: depends - run: | - $ErrorActionPreference = "Stop" - write-host "Downloading AMD HIP SDK Installer" - Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" - write-host "Installing AMD HIP SDK" - Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait - write-host "Completed AMD HIP SDK installation" - - - name: Verify ROCm - id: verify - run: | - & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version - - - name: libCURL - id: get_libcurl - uses: ./.github/actions/windows-setup-curl - - - name: Build - id: cmake_build - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} - run: | - $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) - $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" - cmake -G "Unix Makefiles" -B build -S . ` - -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` - -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/" ` - -DCMAKE_BUILD_TYPE=Release ` - -DAMDGPU_TARGETS=${{ matrix.gpu_target }} ` - -DGGML_HIP_ROCWMMA_FATTN=ON ` - -DGGML_HIP=ON ` - -DGGML_RPC=ON ` - -DCURL_LIBRARY="$env:CURL_PATH/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="$env:CURL_PATH/include" - cmake --build build -j ${env:NUMBER_OF_PROCESSORS} - md "build\bin\rocblas\library\" - cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" - cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\" - cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\" - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - env: - CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} - run: | - cp $env:CURL_PATH\bin\libcurl-x64.dll .\build\bin\libcurl-x64.dll - 7z a llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip .\build\bin\* - - - name: Upload artifacts - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-bin-win-hip-x64-${{ matrix.gpu_target }}.zip - name: llama-bin-win-hip-x64-${{ matrix.gpu_target }}.zip - ios-xcode-build: runs-on: macos-latest steps: - name: Checkout code uses: actions/checkout@v4 - with: - fetch-depth: 0 - name: Build id: cmake_build @@ -1417,6 +999,7 @@ jobs: -DGGML_METAL_EMBED_LIBRARY=ON \ -DLLAMA_CURL=OFF \ -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ -DLLAMA_BUILD_TESTS=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -DCMAKE_SYSTEM_NAME=iOS \ @@ -1432,32 +1015,6 @@ jobs: - name: Build Xcode project run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - zip --symlinks -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - path: llama-${{ steps.tag.outputs.name }}-xcframework.zip - name: llama-${{ steps.tag.outputs.name }}-xcframework - android-build: runs-on: ubuntu-latest @@ -1485,297 +1042,23 @@ jobs: - name: Build run: | cd examples/llama.android - ./gradlew build --no-daemon - release: - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - - runs-on: ubuntu-latest - - needs: - - ubuntu-cpu-cmake - - ubuntu-22-cmake-vulkan - - windows-latest-cmake - - windows-2019-cmake-cuda - - windows-latest-cmake-sycl - - windows-latest-cmake-hip-release - - macOS-latest-cmake-arm64 - - macOS-latest-cmake-x64 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: release - evict-old-files: 1d - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER="$(git rev-list --count HEAD)" - SHORT_HASH="$(git rev-parse --short=7 HEAD)" - if [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "name=b${BUILD_NUMBER}" >> $GITHUB_OUTPUT - else - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - echo "name=${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" >> $GITHUB_OUTPUT - fi - - - name: Download artifacts - id: download-artifact - uses: actions/download-artifact@v4 - with: - path: ./artifact - - - name: Move artifacts - id: move_artifacts - run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release - - - name: Create release - id: create_release - uses: ggml-org/action-create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ steps.tag.outputs.name }} - - - name: Upload release - id: upload_release - uses: actions/github-script@v3 - with: - github-token: ${{secrets.GITHUB_TOKEN}} - script: | - const path = require('path'); - const fs = require('fs'); - const release_id = '${{ steps.create_release.outputs.id }}'; - for (let file of await fs.readdirSync('./artifact/release')) { - if (path.extname(file) === '.zip') { - console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ - owner: context.repo.owner, - repo: context.repo.repo, - release_id: release_id, - name: file, - data: await fs.readFileSync(`./artifact/release/${file}`) - }); - } - } - -# ubuntu-latest-gcc: -# runs-on: ubuntu-latest -# -# strategy: -# matrix: -# build: [Debug, Release] -# -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Dependencies -# run: | -# sudo apt-get update -# sudo apt-get install build-essential -# sudo apt-get install cmake -# -# - name: Configure -# run: cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} -# -# - name: Build -# run: | -# make -# -# ubuntu-latest-clang: -# runs-on: ubuntu-latest -# -# strategy: -# matrix: -# build: [Debug, Release] -# -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Dependencies -# run: | -# sudo apt-get update -# sudo apt-get install build-essential -# sudo apt-get install cmake -# -# - name: Configure -# run: cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang -# -# - name: Build -# run: | -# make -# -# ubuntu-latest-gcc-sanitized: -# runs-on: ubuntu-latest -# -# strategy: -# matrix: -# sanitizer: [ADDRESS, THREAD, UNDEFINED] -# -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Dependencies -# run: | -# sudo apt-get update -# sudo apt-get install build-essential -# sudo apt-get install cmake -# -# - name: Configure -# run: cmake . -DCMAKE_BUILD_TYPE=Debug -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -# -# - name: Build -# run: | -# make -# -# windows: -# runs-on: windows-latest -# -# strategy: -# matrix: -# build: [Release] -# arch: [Win32, x64] -# include: -# - arch: Win32 -# s2arc: x86 -# - arch: x64 -# s2arc: x64 -# -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Add msbuild to PATH -# uses: microsoft/setup-msbuild@v1 -# -# - name: Configure -# run: > -# cmake -S . -B ./build -A ${{ matrix.arch }} -# -DCMAKE_BUILD_TYPE=${{ matrix.build }} -# -# - name: Build -# run: | -# cd ./build -# msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} -# -# - name: Upload binaries -# uses: actions/upload-artifact@v4 -# with: -# name: llama-bin-${{ matrix.arch }} -# path: build/bin/${{ matrix.build }} -# -# windows-blas: -# runs-on: windows-latest -# -# strategy: -# matrix: -# build: [Release] -# arch: [Win32, x64] -# blas: [ON] -# include: -# - arch: Win32 -# obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x86.zip -# s2arc: x86 -# - arch: x64 -# obzip: https://github.com/xianyi/OpenBLAS/releases/download/v0.3.21/OpenBLAS-0.3.21-x64.zip -# s2arc: x64 -# -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Add msbuild to PATH -# uses: microsoft/setup-msbuild@v1 -# -# - name: Fetch OpenBLAS -# if: matrix.blas == 'ON' -# run: | -# C:/msys64/usr/bin/wget.exe -qO blas.zip ${{ matrix.obzip }} -# 7z x blas.zip -oblas -y -# copy blas/include/cblas.h . -# copy blas/include/openblas_config.h . -# echo "blasdir=$env:GITHUB_WORKSPACE/blas" >> $env:GITHUB_ENV -# -# - name: Configure -# run: > -# cmake -S . -B ./build -A ${{ matrix.arch }} -# -DCMAKE_BUILD_TYPE=${{ matrix.build }} -# -DLLAMA_SUPPORT_OPENBLAS=${{ matrix.blas }} -# -DCMAKE_LIBRARY_PATH="$env:blasdir/lib" -# -# - name: Build -# run: | -# cd ./build -# msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} -# -# - name: Copy libopenblas.dll -# if: matrix.blas == 'ON' -# run: copy "$env:blasdir/bin/libopenblas.dll" build/bin/${{ matrix.build }} -# -# - name: Upload binaries -# if: matrix.blas == 'ON' -# uses: actions/upload-artifact@v4 -# with: -# name: llama-blas-bin-${{ matrix.arch }} -# path: build/bin/${{ matrix.build }} -# -# emscripten: -# runs-on: ubuntu-latest -# -# strategy: -# matrix: -# build: [Release] -# -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Dependencies -# run: | -# wget -q https://github.com/emscripten-core/emsdk/archive/master.tar.gz -# tar -xvf master.tar.gz -# emsdk-master/emsdk update -# emsdk-master/emsdk install latest -# emsdk-master/emsdk activate latest -# -# - name: Configure -# run: echo "tmp" -# -# - name: Build -# run: | -# pushd emsdk-master -# source ./emsdk_env.sh -# popd -# emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} -# make - openEuler-latest-cmake-cann: if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} defaults: run: - shell: bash -el {0} - runs-on: ubuntu-24.04-arm + shell: bash -el {0} strategy: matrix: + arch: [x86, aarch64] cann: - '8.1.RC1.alpha001-910b-openeuler22.03-py3.10' device: - 'ascend910b3' build: - 'Release' + runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-24.04' }} container: ascendai/cann:${{ matrix.cann }} steps: - name: Checkout diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 114cbf83e..2067927be 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -36,10 +36,13 @@ jobs: matrix: config: # Multi-stage build - - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false } + # Note: the arm64 images are failing, which prevents the amd64 images from being built + # https://github.com/ggml-org/llama.cpp/issues/11888 + #- { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: false } + - { tag: "cpu", dockerfile: ".devops/cpu.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - { tag: "cuda", dockerfile: ".devops/cuda.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } - { tag: "musa", dockerfile: ".devops/musa.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } - - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } + - { tag: "intel", dockerfile: ".devops/intel.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: true } - { tag: "vulkan", dockerfile: ".devops/vulkan.Dockerfile", platforms: "linux/amd64", full: true, light: true, server: true, free_disk_space: false } # Note: the rocm images are failing due to a compiler error and are disabled until this is fixed to allow the workflow to complete #- {tag: "rocm", dockerfile: ".devops/rocm.Dockerfile", platforms: "linux/amd64,linux/arm64", full: true, light: true, server: true, free_disk_space: true } diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000..9874736cb --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,749 @@ +name: Release + +on: + workflow_dispatch: # allows manual triggering + inputs: + create_release: + description: 'Create new release' + required: true + type: boolean + push: + branches: + - master + paths: ['.github/workflows/release.yml', '**/CMakeLists.txt', '**/.cmake', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.cuh', '**/*.swift', '**/*.m', '**/*.metal', '**/*.comp'] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + CMAKE_ARGS: "-DLLAMA_BUILD_EXAMPLES=OFF -DLLAMA_BUILD_TESTS=OFF -DLLAMA_BUILD_TOOLS=ON -DLLAMA_BUILD_SERVER=ON -DGGML_RPC=ON" + +jobs: + macOS-arm64: + runs-on: macos-14 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-arm64 + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + brew install curl + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DGGML_RPC=ON \ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-arm64.zip + name: llama-bin-macos-arm64.zip + + macOS-x64: + runs-on: macos-13 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: macOS-latest-cmake-x64 + evict-old-files: 1d + + - name: Dependencies + id: depends + continue-on-error: true + run: | + brew update + brew install curl + + - name: Build + id: cmake_build + run: | + sysctl -a + # Metal is disabled due to intermittent failures with Github runners not having a GPU: + # https://github.com/ggml-org/llama.cpp/actions/runs/8635935781/job/23674807267#step:5:2313 + cmake -B build \ + -DCMAKE_BUILD_RPATH="@loader_path" \ + -DLLAMA_FATAL_WARNINGS=ON \ + -DGGML_METAL=OFF \ + -DGGML_RPC=ON + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-macos-x64.zip + name: llama-bin-macos-x64.zip + + ubuntu-22-cpu: + strategy: + matrix: + include: + - build: 'x64' + os: ubuntu-22.04 + # GGML_BACKEND_DL and GGML_CPU_ALL_VARIANTS are not currently supported on arm + # - build: 'arm64' + # os: ubuntu-22.04-arm + + runs-on: ${{ matrix.os }} + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-cpu-cmake + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DGGML_BACKEND_DL=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ALL_VARIANTS=ON \ + -DLLAMA_FATAL_WARNINGS=ON \ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-${{ matrix.build }}.zip + name: llama-bin-ubuntu-${{ matrix.build }}.zip + + ubuntu-22-vulkan: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ubuntu-22-cmake-vulkan + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add - + sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list https://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list + sudo apt-get update -y + sudo apt-get install -y build-essential mesa-vulkan-drivers vulkan-sdk libcurl4-openssl-dev + + - name: Build + id: cmake_build + run: | + cmake -B build \ + -DGGML_BACKEND_DL=ON \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ALL_VARIANTS=ON \ + -DGGML_VULKAN=ON \ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + zip -r llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip ./build/bin/* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.zip + name: llama-bin-ubuntu-vulkan-x64.zip + + windows-cpu: + runs-on: windows-latest + + strategy: + matrix: + include: + - arch: 'x64' + - arch: 'arm64' + + steps: + - name: Clone + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-cpu-${{ matrix.arch }} + variant: ccache + evict-old-files: 1d + + - name: Install Ninja + run: | + choco install ninja + + - name: libCURL + id: get_libcurl + uses: ./.github/actions/windows-setup-curl + with: + architecture: ${{ matrix.arch == 'x64' && 'win64' || 'win64a' }} + + - name: Build + shell: cmd + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" ${{ matrix.arch }} + cmake -S . -B build -G "Ninja Multi-Config" ^ + -D CMAKE_TOOLCHAIN_FILE=cmake/${{ matrix.arch }}-windows-llvm.cmake ^ + -DGGML_NATIVE=OFF ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_CPU_ALL_VARIANTS=${{ matrix.arch == 'x64' && 'ON' || 'OFF' }} ^ + -DGGML_OPENMP=ON ^ + -DCURL_LIBRARY="%CURL_PATH%/lib/libcurl.dll.a" -DCURL_INCLUDE_DIR="%CURL_PATH%/include" ^ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release + + - name: Pack artifacts + id: pack_artifacts + env: + CURL_PATH: ${{ steps.get_libcurl.outputs.curl_path }} + run: | + Copy-Item $env:CURL_PATH\bin\libcurl-${{ matrix.arch }}.dll .\build\bin\Release\ + Copy-Item "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Redist\MSVC\14.42.34433\debug_nonredist\${{ matrix.arch }}\Microsoft.VC143.OpenMP.LLVM\libomp140.${{ matrix.arch == 'x64' && 'x86_64' || 'aarch64' }}.dll" .\build\bin\Release\ + 7z a llama-bin-win-cpu-${{ matrix.arch }}.zip .\build\bin\Release\* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-cpu-${{ matrix.arch }}.zip + name: llama-bin-win-cpu-${{ matrix.arch }}.zip + + windows: + runs-on: windows-latest + + env: + OPENBLAS_VERSION: 0.3.23 + VULKAN_VERSION: 1.4.309.0 + + strategy: + matrix: + include: + - backend: 'vulkan' + arch: 'x64' + defines: '-DGGML_VULKAN=ON' + target: 'ggml-vulkan' + - backend: 'opencl-adreno' + arch: 'arm64' + defines: '-G "Ninja Multi-Config" -D CMAKE_TOOLCHAIN_FILE=cmake/arm64-windows-llvm.cmake -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" -DGGML_OPENCL=ON -DGGML_OPENCL_USE_ADRENO_KERNELS=ON' + target: 'ggml-opencl' + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-${{ matrix.backend }}-${{ matrix.arch }} + variant: ccache + evict-old-files: 1d + + - name: Install Vulkan SDK + id: get_vulkan + if: ${{ matrix.backend == 'vulkan' }} + run: | + curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" + & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install + Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" + Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Install OpenCL Headers and Libs + id: install_opencl + if: ${{ matrix.backend == 'opencl-adreno' && matrix.arch == 'arm64' }} + run: | + git clone https://github.com/KhronosGroup/OpenCL-Headers + cd OpenCL-Headers + cmake -B build ` + -DBUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_TESTING=OFF ` + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build build --target install + git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader + cd OpenCL-ICD-Loader + cmake -B build-arm64-release ` + -A arm64 ` + -DCMAKE_PREFIX_PATH="$env:RUNNER_TEMP/opencl-arm64-release" ` + -DCMAKE_INSTALL_PREFIX="$env:RUNNER_TEMP/opencl-arm64-release" + cmake --build build-arm64-release --target install --config release + + - name: Build + id: cmake_build + run: | + cmake -S . -B build ${{ matrix.defines }} -DGGML_NATIVE=OFF -DGGML_CPU=OFF -DGGML_BACKEND_DL=ON -DLLAMA_CURL=OFF + cmake --build build --config Release --target ${{ matrix.target }} + + - name: Pack artifacts + id: pack_artifacts + run: | + 7z a llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip .\build\bin\Release\${{ matrix.target }}.dll + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip + name: llama-bin-win-${{ matrix.backend }}-${{ matrix.arch }}.zip + + windows-cuda: + runs-on: windows-2022 + + strategy: + matrix: + cuda: ['12.4'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-cuda-${{ matrix.cuda }} + variant: ccache + evict-old-files: 1d + + - name: Install Cuda Toolkit + uses: ./.github/actions/windows-setup-cuda + with: + cuda_version: ${{ matrix.cuda }} + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Build + id: cmake_build + shell: cmd + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CPU=OFF ^ + -DGGML_CUDA=ON ^ + -DLLAMA_CURL=OFF + set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 + cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda + + - name: Pack artifacts + id: pack_artifacts + run: | + 7z a llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip .\build\bin\Release\ggml-cuda.dll + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + name: llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + + - name: Copy and pack Cuda runtime + run: | + echo "Cuda install location: ${{ env.CUDA_PATH }}" + $dst='.\build\bin\cudart\' + robocopy "${{env.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll + robocopy "${{env.CUDA_PATH}}\lib" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll + 7z a cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip $dst\* + + - name: Upload Cuda runtime + uses: actions/upload-artifact@v4 + with: + path: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + name: cudart-llama-bin-win-cuda-${{ matrix.cuda }}-x64.zip + + windows-sycl: + runs-on: windows-latest + + defaults: + run: + shell: bash + + env: + WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7cd9bba0-7aab-4e30-b3ae-2221006a4a05/intel-oneapi-base-toolkit-2025.1.1.34_offline.exe + WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel:intel.oneapi.win.dnnl:intel.oneapi.win.tbb.devel + ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-sycl + variant: ccache + evict-old-files: 1d + + - name: Install + run: | + scripts/install-oneapi.bat $WINDOWS_BASEKIT_URL $WINDOWS_DPCPP_MKL + + - name: Build + id: cmake_build + shell: cmd + run: | + call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force + cmake -G "Ninja" -B build ^ + -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx ^ + -DCMAKE_BUILD_TYPE=Release ^ + -DGGML_BACKEND_DL=ON -DBUILD_SHARED_LIBS=ON ^ + -DGGML_CPU=OFF -DGGML_SYCL=ON ^ + -DLLAMA_CURL=OFF + cmake --build build --target ggml-sycl -j + + - name: Build the release package + id: pack_artifacts + run: | + echo "cp oneAPI running time dll files in ${{ env.ONEAPI_ROOT }} to ./build/bin" + + cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_sycl_blas.5.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_core.2.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/mkl/latest/bin/mkl_tbb_thread.2.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_level_zero.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_adapter_opencl.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_loader.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/ur_win_proxy_loader.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/sycl8.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/svml_dispmd.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libmmd.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/compiler/latest/bin/libiomp5md.dll" ./build/bin + + cp "${{ env.ONEAPI_ROOT }}/dnnl/latest/bin/dnnl.dll" ./build/bin + cp "${{ env.ONEAPI_ROOT }}/tbb/latest/bin/tbb12.dll" ./build/bin + + echo "cp oneAPI running time dll files to ./build/bin done" + 7z a llama-bin-win-sycl-x64.zip ./build/bin/* + + - name: Upload the release package + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-sycl-x64.zip + name: llama-bin-win-sycl-x64.zip + + windows-hip: + runs-on: windows-latest + + strategy: + matrix: + include: + - name: "radeon" + gpu_targets: "gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032" + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Clone rocWMMA repository + id: clone_rocwmma + run: | + git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 + + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: windows-latest-cmake-hip-${{ matrix.name }}-x64 + evict-old-files: 1d + + - name: Install + id: depends + run: | + $ErrorActionPreference = "Stop" + write-host "Downloading AMD HIP SDK Installer" + Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" + write-host "Installing AMD HIP SDK" + Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait + write-host "Completed AMD HIP SDK installation" + + - name: Verify ROCm + id: verify + run: | + & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version + + - name: Build + id: cmake_build + run: | + $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path) + $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}" + cmake -G "Unix Makefiles" -B build -S . ` + -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` + -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/rocwmma/library/include/ -Wno-ignored-attributes -Wno-nested-anon-types" ` + -DCMAKE_BUILD_TYPE=Release ` + -DGGML_BACKEND_DL=ON ` + -DGGML_NATIVE=OFF ` + -DGGML_CPU=OFF ` + -DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" ` + -DGGML_HIP_ROCWMMA_FATTN=ON ` + -DGGML_HIP=ON ` + -DLLAMA_CURL=OFF + cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS} + md "build\bin\rocblas\library\" + cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\" + + - name: Pack artifacts + id: pack_artifacts + run: | + 7z a llama-bin-win-hip-${{ matrix.name }}-x64.zip .\build\bin\* + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-bin-win-hip-${{ matrix.name }}-x64.zip + name: llama-bin-win-hip-${{ matrix.name }}-x64.zip + + ios-xcode-build: + runs-on: macos-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Build + id: cmake_build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DLLAMA_CURL=OFF \ + -DLLAMA_BUILD_EXAMPLES=OFF \ + -DLLAMA_BUILD_TOOLS=OFF \ + -DLLAMA_BUILD_TESTS=OFF \ + -DLLAMA_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + - name: xcodebuild for swift package + id: xcodebuild + run: | + ./build-xcframework.sh + + - name: Build Xcode project + run: xcodebuild -project examples/llama.swiftui/llama.swiftui.xcodeproj -scheme llama.swiftui -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + zip --symlinks -r llama-${{ steps.tag.outputs.name }}-xcframework.zip build-apple/llama.xcframework + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + path: llama-${{ steps.tag.outputs.name }}-xcframework.zip + name: llama-${{ steps.tag.outputs.name }}-xcframework + + release: + if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + + # Fine-grant permission + # https://docs.github.com/en/actions/security-for-github-actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token + permissions: + contents: write # for creating release + + runs-on: ubuntu-latest + + needs: + - windows + - windows-cpu + - windows-cuda + - windows-sycl + - windows-hip + - ubuntu-22-cpu + - ubuntu-22-vulkan + - macOS-arm64 + - macOS-x64 + - ios-xcode-build + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Download artifacts + id: download-artifact + uses: actions/download-artifact@v4 + with: + path: ./artifact + merge-multiple: true + + - name: Move artifacts + id: move_artifacts + run: | + mkdir -p release + + echo "Adding CPU backend files to existing zips..." + for arch in x64 arm64; do + cpu_zip="artifact/llama-bin-win-cpu-${arch}.zip" + temp_dir=$(mktemp -d) + echo "Extracting CPU backend for $arch..." + unzip "$cpu_zip" -d "$temp_dir" + + echo "Adding CPU files to $arch zips..." + for target_zip in artifact/llama-bin-win-*-${arch}.zip; do + if [[ "$target_zip" == "$cpu_zip" ]]; then + continue + fi + echo "Adding CPU backend to $(basename "$target_zip")" + realpath_target_zip=$(realpath "$target_zip") + (cd "$temp_dir" && zip -r "$realpath_target_zip" .) + done + + rm -rf "$temp_dir" + done + + echo "Renaming and moving zips to release..." + for zip_file in artifact/llama-bin-win-*.zip; do + base_name=$(basename "$zip_file" .zip) + zip_name="llama-${{ steps.tag.outputs.name }}-${base_name#llama-}.zip" + echo "Moving $zip_file to release/$zip_name" + mv "$zip_file" "release/$zip_name" + done + + echo "Moving other artifacts..." + mv -v artifact/*.zip release + + - name: Create release + id: create_release + uses: ggml-org/action-create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.tag.outputs.name }} + + - name: Upload release + id: upload_release + uses: actions/github-script@v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + const path = require('path'); + const fs = require('fs'); + const release_id = '${{ steps.create_release.outputs.id }}'; + for (let file of await fs.readdirSync('./release')) { + if (path.extname(file) === '.zip') { + console.log('uploadReleaseAsset', file); + await github.repos.uploadReleaseAsset({ + owner: context.repo.owner, + repo: context.repo.repo, + release_id: release_id, + name: file, + data: await fs.readFileSync(`./release/${file}`) + }); + } + } diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 6c9b51322..f6da48857 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -15,10 +15,10 @@ on: push: branches: - master - paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*'] + paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] pull_request: types: [opened, synchronize, reopened] - paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*'] + paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] env: LLAMA_LOG_COLORS: 1 @@ -74,7 +74,7 @@ jobs: - name: Tests dependencies id: test_dependencies run: | - pip install -r examples/server/tests/requirements.txt + pip install -r tools/server/tests/requirements.txt # Setup nodejs (to be used for verifying bundled index.html) - uses: actions/setup-node@v4 @@ -84,14 +84,14 @@ jobs: - name: WebUI - Install dependencies id: webui_lint run: | - cd examples/server/webui + cd tools/server/webui npm ci - name: WebUI - Check code format id: webui_format run: | git config --global --add safe.directory $(realpath .) - cd examples/server/webui + cd tools/server/webui git status npm run format @@ -108,7 +108,7 @@ jobs: id: verify_server_index_html run: | git config --global --add safe.directory $(realpath .) - cd examples/server/webui + cd tools/server/webui git status npm run build @@ -161,26 +161,26 @@ jobs: env: GITHUB_ACTIONS: "true" run: | - cd examples/server/tests + cd tools/server/tests ./tests.sh - name: Tests (sanitizers) id: server_integration_tests_sanitizers if: ${{ matrix.sanitizer != '' }} run: | - cd examples/server/tests + cd tools/server/tests LLAMA_SANITIZE=1 ./tests.sh - name: Slow tests id: server_integration_tests_slow if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} run: | - cd examples/server/tests + cd tools/server/tests SLOW_TESTS=1 ./tests.sh server-windows: - runs-on: windows-2019 + runs-on: windows-2022 steps: - name: Clone @@ -211,7 +211,7 @@ jobs: - name: Tests dependencies id: test_dependencies run: | - pip install -r examples/server/tests/requirements.txt + pip install -r tools/server/tests/requirements.txt - name: Copy Libcurl id: prepare_libcurl @@ -224,7 +224,7 @@ jobs: id: server_integration_tests if: ${{ !matrix.disabled_on_pr || !github.event.pull_request }} run: | - cd examples/server/tests + cd tools/server/tests $env:PYTHONIOENCODING = ":replace" pytest -v -x -m "not slow" @@ -232,6 +232,6 @@ jobs: id: server_integration_tests_slow if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} run: | - cd examples/server/tests + cd tools/server/tests $env:SLOW_TESTS = "1" pytest -v -x diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml new file mode 100644 index 000000000..5c2861559 --- /dev/null +++ b/.github/workflows/winget.yml @@ -0,0 +1,42 @@ +name: Update Winget Package + +on: + workflow_dispatch: # allows manual triggering + schedule: + - cron: '28 5 * * *' # Update every day at 5:28 UTC + +jobs: + update: + name: Update Winget Package + runs-on: ubuntu-latest + + steps: + - name: Install cargo binstall + uses: cargo-bins/cargo-binstall@268643a6b5ea099f5718ee5cd3ff7dc89a5eb49b + + - name: Install komac + run: | + cargo binstall komac@2.11.2 -y + + - name: Find latest release + id: find_latest_release + uses: actions/github-script@v6 + with: + script: | + const { data: releases } = await github.rest.repos.listReleases({ + owner: context.repo.owner, + repo: context.repo.repo, + }); + console.log("Latest release:", releases[0].tag_name); + return releases[0].tag_name; + + - name: Update manifest + env: + VERSION: ${{ steps.find_latest_release.outputs.result }} + run: | + echo "Updating manifest..." + komac update --version ${{ env.VERSION }} \ + --urls "https://github.com/ggml-org/llama.cpp/releases/download/${{ env.VERSION }}/llama-${{ env.VERSION }}-bin-win-vulkan-x64.zip" \ + --token ${{ secrets.WINGET_GITHUB_TOKEN }} \ + --submit \ + ggml.llamacpp diff --git a/.gitignore b/.gitignore index 2c67ad7f7..f8ceb1560 100644 --- a/.gitignore +++ b/.gitignore @@ -96,11 +96,11 @@ perf-*.txt # Examples examples/jeopardy/results.txt -examples/server/*.css.hpp -examples/server/*.html.hpp -examples/server/*.js.hpp -examples/server/*.mjs.hpp -examples/server/*.gz.hpp +tools/server/*.css.hpp +tools/server/*.html.hpp +tools/server/*.js.hpp +tools/server/*.mjs.hpp +tools/server/*.gz.hpp !build_64.sh !examples/*.bat !examples/*/*.kts @@ -110,7 +110,7 @@ examples/server/*.gz.hpp # Server Web UI temporary files node_modules -examples/server/webui/dist +tools/server/webui/dist # Python diff --git a/CMakeLists.txt b/CMakeLists.txt index de51c0a17..50801cdc6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -77,6 +77,7 @@ option(LLAMA_BUILD_COMMON "llama: build common utils library" ${LLAMA_STANDALONE # extra artifacts option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) +option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) @@ -88,6 +89,14 @@ option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/common.cmake) +if (NOT DEFINED LLAMA_BUILD_NUMBER) + set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER}) +endif() +if (NOT DEFINED LLAMA_BUILD_COMMIT) + set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT}) +endif() +set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER}) + # override ggml options set(GGML_ALL_WARNINGS ${LLAMA_ALL_WARNINGS}) set(GGML_FATAL_WARNINGS ${LLAMA_FATAL_WARNINGS}) @@ -154,10 +163,17 @@ if (LLAMA_USE_SYSTEM_GGML) endif() if (NOT TARGET ggml AND NOT LLAMA_USE_SYSTEM_GGML) + set(GGML_BUILD_NUMBER ${LLAMA_BUILD_NUMBER}) + set(GGML_BUILD_COMMIT ${LLAMA_BUILD_COMMIT}) add_subdirectory(ggml) # ... otherwise assume ggml is added by a parent CMakeLists.txt endif() +if (MINGW) + # Target Windows 8 for PrefetchVirtualMemory + add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) +endif() + # # build the library # @@ -187,6 +203,10 @@ if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_EXAMPLES) add_subdirectory(pocs) endif() +if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS) + add_subdirectory(tools) +endif() + # # install # @@ -194,10 +214,6 @@ endif() include(GNUInstallDirs) include(CMakePackageConfigHelpers) -set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER}) -set(LLAMA_BUILD_COMMIT ${BUILD_COMMIT}) -set(LLAMA_INSTALL_VERSION 0.0.${BUILD_NUMBER}) - set(LLAMA_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") set(LLAMA_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") set(LLAMA_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") @@ -247,20 +263,3 @@ configure_file(cmake/llama.pc.in install(FILES "${CMAKE_CURRENT_BINARY_DIR}/llama.pc" DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) - -# -# copy the license files -# - -# Check if running in GitHub Actions -if(DEFINED ENV{GITHUB_ACTIONS} AND "$ENV{GITHUB_ACTIONS}" STREQUAL "true") - message(STATUS "Running inside GitHub Actions - copying license files") - - # Copy all files from licenses/ to build/bin/ - file(GLOB LICENSE_FILES "${CMAKE_SOURCE_DIR}/licenses/*") - foreach(LICENSE_FILE ${LICENSE_FILES}) - get_filename_component(FILENAME ${LICENSE_FILE} NAME) - configure_file(${LICENSE_FILE} "${CMAKE_BINARY_DIR}/bin/${FILENAME}" COPYONLY) - endforeach() -endif() - diff --git a/CMakePresets.json b/CMakePresets.json index 13bdd7907..e98447013 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -38,15 +38,6 @@ } }, - { - "name": "arm64-windows-msvc", "hidden": true, - "architecture": { "value": "arm64", "strategy": "external" }, - "toolset": { "value": "host=x64", "strategy": "external" }, - "cacheVariables": { - "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-msvc.cmake" - } - }, - { "name": "arm64-windows-llvm", "hidden": true, "architecture": { "value": "arm64", "strategy": "external" }, @@ -73,10 +64,6 @@ { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, - { "name": "arm64-windows-msvc-debug", "inherits": [ "base", "arm64-windows-msvc", "debug" ] }, - { "name": "arm64-windows-msvc-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg" ] }, - { "name": "arm64-windows-msvc+static-release", "inherits": [ "base", "arm64-windows-msvc", "reldbg", "static" ] }, - { "name": "x64-windows-llvm-debug", "inherits": [ "base", "x64-windows-llvm", "debug" ] }, { "name": "x64-windows-llvm-release", "inherits": [ "base", "x64-windows-llvm", "release" ] }, { "name": "x64-windows-llvm-reldbg", "inherits": [ "base", "x64-windows-llvm", "reldbg" ] }, diff --git a/CODEOWNERS b/CODEOWNERS index 72d594b46..3186f8eb1 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,7 +2,7 @@ /ci/ @ggerganov /.devops/*.Dockerfile @ngxson -/examples/server/ @ngxson +/tools/server/ @ngxson /ggml/src/ggml-cuda/fattn* @JohannesGaessler /ggml/src/ggml-cuda/mmq.* @JohannesGaessler /ggml/src/ggml-cuda/mmv.* @JohannesGaessler diff --git a/Makefile b/Makefile index 1f9455eff..ac442aec0 100644 --- a/Makefile +++ b/Makefile @@ -367,7 +367,7 @@ ifdef LLAMA_SERVER_SSL endif ifndef GGML_NO_CPU_AARCH64 - MK_CPPFLAGS += -DGGML_USE_CPU_AARCH64 + MK_CPPFLAGS += -DGGML_USE_CPU_REPACK endif # warnings @@ -780,10 +780,6 @@ ifdef GGML_HIP MK_CPPFLAGS += -DGGML_USE_HIP -DGGML_USE_CUDA -ifdef GGML_HIP_UMA - MK_CPPFLAGS += -DGGML_HIP_UMA -endif # GGML_HIP_UMA - MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib MK_LDFLAGS += -L$(ROCM_PATH)/lib64 -Wl,-rpath=$(ROCM_PATH)/lib64 MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas @@ -974,7 +970,7 @@ OBJ_GGML = \ $(DIR_GGML)/src/ggml-threading.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu_cpp.o \ - $(DIR_GGML)/src/ggml-cpu/ggml-cpu-aarch64.o \ + $(DIR_GGML)/src/ggml-cpu/repack.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu-hbm.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu-quants.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu-traits.o \ @@ -1160,10 +1156,10 @@ $(LIB_COMMON_S): $(OBJ_COMMON) # Clean generated server assets clean-server-assets: - find examples/server -type f -name "*.js.hpp" -delete - find examples/server -type f -name "*.mjs.hpp" -delete - find examples/server -type f -name "*.css.hpp" -delete - find examples/server -type f -name "*.html.hpp" -delete + find tools/server -type f -name "*.js.hpp" -delete + find tools/server -type f -name "*.mjs.hpp" -delete + find tools/server -type f -name "*.css.hpp" -delete + find tools/server -type f -name "*.html.hpp" -delete # Clean rule clean: clean-server-assets @@ -1183,7 +1179,7 @@ clean: clean-server-assets # Helper function that replaces .c, .cpp, and .cu file endings with .o: GET_OBJ_FILE = $(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(patsubst %.cu,%.o,$(1)))) -llama-cli: examples/main/main.cpp \ +llama-cli: tools/main/main.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1191,12 +1187,7 @@ llama-cli: examples/main/main.cpp \ @echo '==== Run ./llama-cli -h for help. ====' @echo -llama-infill: examples/infill/infill.cpp \ - $(OBJ_ALL) - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - -llama-run: examples/run/run.cpp \ +llama-run: tools/run/run.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1211,7 +1202,7 @@ llama-simple-chat: examples/simple-chat/simple-chat.cpp \ $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-tokenize: examples/tokenize/tokenize.cpp \ +llama-tokenize: tools/tokenize/tokenize.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1221,27 +1212,27 @@ llama-batched: examples/batched/batched.cpp \ $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-batched-bench: examples/batched-bench/batched-bench.cpp \ +llama-batched-bench: tools/batched-bench/batched-bench.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-quantize: examples/quantize/quantize.cpp \ +llama-quantize: tools/quantize/quantize.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-quantize-stats: examples/quantize-stats/quantize-stats.cpp \ +llama-quantize-stats: tools/quantize-stats/quantize-stats.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-perplexity: examples/perplexity/perplexity.cpp \ +llama-perplexity: tools/perplexity/perplexity.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-imatrix: examples/imatrix/imatrix.cpp \ +llama-imatrix: tools/imatrix/imatrix.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1283,7 +1274,7 @@ llama-gguf-hash: examples/gguf-hash/gguf-hash.cpp examples/gguf-hash/deps/sha1/s $(CXX) $(CXXFLAGS) -Iexamples/gguf-hash/deps -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-gguf-split: examples/gguf-split/gguf-split.cpp \ +llama-gguf-split: tools/gguf-split/gguf-split.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1293,7 +1284,7 @@ llama-eval-callback: examples/eval-callback/eval-callback.cpp \ $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-cvector-generator: examples/cvector-generator/cvector-generator.cpp \ +llama-cvector-generator: tools/cvector-generator/cvector-generator.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1303,12 +1294,12 @@ llama-convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c- $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-bench: examples/llama-bench/llama-bench.cpp \ +llama-bench: tools/llama-bench/llama-bench.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -llama-export-lora: examples/export-lora/export-lora.cpp \ +llama-export-lora: tools/export-lora/export-lora.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) @@ -1364,17 +1355,17 @@ llama-gbnf-validator: examples/gbnf-validator/gbnf-validator.cpp \ $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) ifdef GGML_RPC -rpc-server: examples/rpc/rpc-server.cpp \ +rpc-server: tools/rpc/rpc-server.cpp \ $(OBJ_GGML) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) endif # GGML_RPC llama-server: \ - examples/server/server.cpp \ - examples/server/utils.hpp \ - examples/server/httplib.h \ - examples/server/index.html.hpp \ - examples/server/loading.html.hpp \ + tools/server/server.cpp \ + tools/server/utils.hpp \ + tools/server/httplib.h \ + tools/server/index.html.hpp \ + tools/server/loading.html.hpp \ common/chat.cpp \ common/chat.h \ common/chat-template.hpp \ @@ -1382,10 +1373,10 @@ llama-server: \ common/minja.hpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) + $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Itools/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) -# Portable equivalent of `cd examples/server/public && xxd -i $(notdir $<) ../$(notdir $<).hpp`: -examples/server/%.hpp: examples/server/public/% FORCE Makefile +# Portable equivalent of `cd tools/server/public && xxd -i $(notdir $<) ../$(notdir $<).hpp`: +tools/server/%.hpp: tools/server/public/% FORCE Makefile @( export NAME=$(subst .,_,$(subst -,_,$(notdir $<))) && \ echo "unsigned char $${NAME}[] = {" && \ cat $< | od -v -t x1 -An | sed -E 's/([0-9a-fA-F]+)/0x\1, /g' && \ @@ -1398,36 +1389,36 @@ llama-gen-docs: examples/gen-docs/gen-docs.cpp \ $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -libllava.a: examples/llava/llava.cpp \ - examples/llava/llava.h \ - examples/llava/clip.cpp \ - examples/llava/clip.h \ +libllava.a: tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ common/stb_image.h \ common/base64.hpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -static -fPIC -c $< -o $@ -Wno-cast-qual -llama-llava-cli: examples/llava/llava-cli.cpp \ - examples/llava/llava.cpp \ - examples/llava/llava.h \ - examples/llava/clip.cpp \ - examples/llava/clip.h \ +llama-llava-cli: tools/mtmd/llava-cli.cpp \ + tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual -llama-minicpmv-cli: examples/llava/minicpmv-cli.cpp \ - examples/llava/llava.cpp \ - examples/llava/llava.h \ - examples/llava/clip.cpp \ - examples/llava/clip.h \ +llama-minicpmv-cli: tools/mtmd/minicpmv-cli.cpp \ + tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual -llama-qwen2vl-cli: examples/llava/qwen2vl-cli.cpp \ - examples/llava/llava.cpp \ - examples/llava/llava.h \ - examples/llava/clip.cpp \ - examples/llava/clip.h \ +llama-qwen2vl-cli: tools/mtmd/qwen2vl-cli.cpp \ + tools/mtmd/llava.cpp \ + tools/mtmd/llava.h \ + tools/mtmd/clip.cpp \ + tools/mtmd/clip.h \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) $< $(filter-out %.h $<,$^) -o $@ $(LDFLAGS) -Wno-cast-qual @@ -1484,12 +1475,12 @@ tests/test-double-float: tests/test-double-float.cpp tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \ $(OBJ_ALL) - $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) -Itools/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) tests/test-chat: tests/test-chat.cpp \ $(OBJ_ALL) - $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) -Itools/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) tests/test-opt: tests/test-opt.cpp \ diff --git a/README.md b/README.md index cf45f23cf..90c7364df 100644 --- a/README.md +++ b/README.md @@ -3,9 +3,10 @@ ![llama](https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png) [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](https://opensource.org/licenses/MIT) +[![Release](https://img.shields.io/github/v/release/ggml-org/llama.cpp)](https://github.com/ggml-org/llama.cpp/releases) [![Server](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml/badge.svg)](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml) -[Roadmap](https://github.com/users/ggerganov/projects/7) / [Project status](https://github.com/ggml-org/llama.cpp/discussions/3471) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) +[Roadmap](https://github.com/users/ggerganov/projects/7) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++ @@ -16,8 +17,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ## Hot topics -- **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggml-org/llama.cpp/pull/11427 -- **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode +- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md) +- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated +- VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim - Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123 @@ -26,6 +28,30 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) ---- +## Quick start + +Getting started with llama.cpp is straightforward. Here are several ways to install it on your machine: + +- Install `llama.cpp` using [brew, nix or winget](docs/install.md) +- Run with Docker - see our [Docker documentation](docs/docker.md) +- Download pre-built binaries from the [releases page](https://github.com/ggml-org/llama.cpp/releases) +- Build from source by cloning this repository - check out [our build guide](docs/build.md) + +Once installed, you'll need a model to work with. Head to the [Obtaining and quantizing models](#obtaining-and-quantizing-models) section to learn more. + +Example command: + +```sh +# Use a local model file +llama-cli -m my_model.gguf + +# Or download and run a model directly from Hugging Face +llama-cli -hf ggml-org/gemma-3-1b-it-GGUF + +# Launch OpenAI-compatible API server +llama-server -hf ggml-org/gemma-3-1b-it-GGUF +``` + ## Description The main goal of `llama.cpp` is to enable LLM inference with minimal setup and state-of-the-art performance on a wide @@ -35,7 +61,7 @@ range of hardware - locally and in the cloud. - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks - AVX, AVX2, AVX512 and AMX support for x86 architectures - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use -- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads MTT GPUs via MUSA) +- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA) - Vulkan and SYCL backend support - CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity @@ -128,6 +154,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
Bindings +- Python: [ddh0/easy-llama](https://github.com/ddh0/easy-llama) - Python: [abetlen/llama-cpp-python](https://github.com/abetlen/llama-cpp-python) - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) @@ -227,6 +254,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
+ ## Supported backends | Backend | Target devices | @@ -235,23 +263,13 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [BLAS](docs/build.md#blas-build) | All | | [BLIS](docs/backend/BLIS.md) | All | | [SYCL](docs/backend/SYCL.md) | Intel and Nvidia GPU | -| [MUSA](docs/build.md#musa) | Moore Threads MTT GPU | +| [MUSA](docs/build.md#musa) | Moore Threads GPU | | [CUDA](docs/build.md#cuda) | Nvidia GPU | | [HIP](docs/build.md#hip) | AMD GPU | | [Vulkan](docs/build.md#vulkan) | GPU | | [CANN](docs/build.md#cann) | Ascend NPU | | [OpenCL](docs/backend/OPENCL.md) | Adreno GPU | -| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/examples/rpc) | All | - -## Building the project - -The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h). -The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server. Possible methods for obtaining the binaries: - -- Clone this repository and build locally, see [how to build](docs/build.md) -- On MacOS or Linux, install `llama.cpp` via [brew, flox or nix](docs/install.md) -- Use a Docker image, see [documentation for Docker](docs/docker.md) -- Download pre-built binaries from [releases](https://github.com/ggml-org/llama.cpp/releases) +| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | ## Obtaining and quantizing models @@ -260,7 +278,11 @@ The [Hugging Face](https://huggingface.co) platform hosts a [number of LLMs](htt - [Trending](https://huggingface.co/models?library=gguf&sort=trending) - [LLaMA](https://huggingface.co/models?sort=trending&search=llama+gguf) -You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf /[:quant]`. +You can either manually download the GGUF file or directly use any `llama.cpp`-compatible models from [Hugging Face](https://huggingface.co/) or other model hosting sites, such as [ModelScope](https://modelscope.cn/), by using this CLI argument: `-hf /[:quant]`. For example: + +```sh +llama-cli -hf ggml-org/gemma-3-1b-it-GGUF +``` By default, the CLI would download from Hugging Face, you can switch to other options with the environment variable `MODEL_ENDPOINT`. For example, you may opt to downloading model checkpoints from ModelScope or other model sharing communities by setting the environment variable, e.g. `MODEL_ENDPOINT=https://www.modelscope.cn/`. @@ -275,9 +297,9 @@ The Hugging Face platform provides a variety of online tools for converting, qua - Use the [GGUF-editor space](https://huggingface.co/spaces/CISCai/gguf-editor) to edit GGUF meta data in the browser (more info: https://github.com/ggml-org/llama.cpp/discussions/9268) - Use the [Inference Endpoints](https://ui.endpoints.huggingface.co/) to directly host `llama.cpp` in the cloud (more info: https://github.com/ggml-org/llama.cpp/discussions/9669) -To learn more about model quantization, [read this documentation](examples/quantize/README.md) +To learn more about model quantization, [read this documentation](tools/quantize/README.md) -## [`llama-cli`](examples/main) +## [`llama-cli`](tools/main) #### A CLI tool for accessing and experimenting with most of `llama.cpp`'s functionality. @@ -340,7 +362,7 @@ To learn more about model quantization, [read this documentation](examples/quant -## [`llama-server`](examples/server) +## [`llama-server`](tools/server) #### A lightweight, [OpenAI API](https://github.com/openai/openai-openapi) compatible, HTTP server for serving LLMs. @@ -410,7 +432,7 @@ To learn more about model quantization, [read this documentation](examples/quant -## [`llama-perplexity`](examples/perplexity) +## [`llama-perplexity`](tools/perplexity) #### A tool for measuring the perplexity [^1][^2] (and other quality metrics) of a model over a given text. @@ -435,10 +457,10 @@ To learn more about model quantization, [read this documentation](examples/quant -[^1]: [examples/perplexity/README.md](./examples/perplexity/README.md) +[^1]: [tools/perplexity/README.md](./tools/perplexity/README.md) [^2]: [https://huggingface.co/docs/transformers/perplexity](https://huggingface.co/docs/transformers/perplexity) -## [`llama-bench`](examples/llama-bench) +## [`llama-bench`](tools/llama-bench) #### Benchmark the performance of the inference for various parameters. @@ -459,7 +481,7 @@ To learn more about model quantization, [read this documentation](examples/quant -## [`llama-run`](examples/run) +## [`llama-run`](tools/run) #### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3]. @@ -503,8 +525,8 @@ To learn more about model quantization, [read this documentation](examples/quant ## Other documentation -- [main (cli)](examples/main/README.md) -- [server](examples/server/README.md) +- [main (cli)](tools/main/README.md) +- [server](tools/server/README.md) - [GBNF grammars](grammars/README.md) #### Development documentation @@ -570,4 +592,12 @@ automatically. For example: $ echo "source ~/.llama-completion.bash" >> ~/.bashrc ``` -## References +## Dependencies + +- [yhirose/cpp-httplib](https://github.com/yhirose/cpp-httplib) - Single-header HTTP server, used by `llama-server` - MIT license +- [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain +- [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License +- [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License +- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License +- [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) +- [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain diff --git a/SECURITY.md b/SECURITY.md index 6a1bb6c32..9749e95b7 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -40,7 +40,8 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru ### Untrusted environments or networks If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions: -* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value +* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/tools/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061). +* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value. * Encrypt your data if sending it over the network. ### Multi-Tenant environments diff --git a/build-xcframework.sh b/build-xcframework.sh index 97001b5f7..a08419a80 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -8,6 +8,7 @@ TVOS_MIN_OS_VERSION=16.4 BUILD_SHARED_LIBS=OFF LLAMA_BUILD_EXAMPLES=OFF +LLAMA_BUILD_TOOLS=OFF LLAMA_BUILD_TESTS=OFF LLAMA_BUILD_SERVER=OFF GGML_METAL=ON @@ -31,6 +32,7 @@ COMMON_CMAKE_ARGS=( -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS} -DLLAMA_BUILD_EXAMPLES=${LLAMA_BUILD_EXAMPLES} + -DLLAMA_BUILD_TOOLS=${LLAMA_BUILD_TOOLS} -DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS} -DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER} -DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY} @@ -115,6 +117,7 @@ setup_framework_structure() { # Copy all required headers (common for all platforms) cp include/llama.h ${header_path} cp ggml/include/ggml.h ${header_path} + cp ggml/include/ggml-opt.h ${header_path} cp ggml/include/ggml-alloc.h ${header_path} cp ggml/include/ggml-backend.h ${header_path} cp ggml/include/ggml-metal.h ${header_path} diff --git a/ci/README.md b/ci/README.md index ec3f44350..6e297f1a8 100644 --- a/ci/README.md +++ b/ci/README.md @@ -54,7 +54,7 @@ docker run --privileged -it \ -v $HOME/llama.cpp/ci-cache:/ci-cache \ -v $HOME/llama.cpp/ci-results:/ci-results \ -v $PWD:/ws -w /ws \ - mthreads/musa:rc3.1.1-devel-ubuntu22.04 + mthreads/musa:rc4.0.1-mudnn-devel-ubuntu22.04 ``` Inside the container, execute the following commands: diff --git a/ci/run.sh b/ci/run.sh index f463d7a8b..940055705 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -39,14 +39,27 @@ sd=`dirname $0` cd $sd/../ SRC=`pwd` -CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=OFF" +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON" if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON -DGGML_METAL_USE_BF16=ON" fi if [ ! -z ${GG_BUILD_CUDA} ]; then - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native" + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON" + + if command -v nvidia-smi >/dev/null 2>&1; then + CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.') + if [[ -n "$CUDA_ARCH" && "$CUDA_ARCH" =~ ^[0-9]+$ ]]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=${CUDA_ARCH}" + else + echo "Warning: Using fallback CUDA architectures" + CMAKE_EXTRA="${CMAKE_EXTRA} -DCMAKE_CUDA_ARCHITECTURES=61;70;75;80;86;89" + fi + else + echo "Error: nvidia-smi not found, cannot build with CUDA" + exit 1 + fi fi if [ ! -z ${GG_BUILD_SYCL} ]; then @@ -187,8 +200,8 @@ function gg_run_test_scripts_debug { set -e - (cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log - (cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./tools/gguf-split && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./tools/quantize && time bash tests.sh "$SRC/build-ci-debug/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log set +e } @@ -211,8 +224,8 @@ function gg_run_test_scripts_release { set -e - (cd ./examples/gguf-split && time bash tests.sh "$SRC/build-ci-release/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log - (cd ./examples/quantize && time bash tests.sh "$SRC/build-ci-release/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./tools/gguf-split && time bash tests.sh "$SRC/build-ci-release/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log + (cd ./tools/quantize && time bash tests.sh "$SRC/build-ci-release/bin" "$MNT/models") 2>&1 | tee -a $OUT/${ci}-scripts.log set +e } diff --git a/cmake/arm64-windows-msvc.cmake b/cmake/arm64-windows-msvc.cmake deleted file mode 100644 index c77631420..000000000 --- a/cmake/arm64-windows-msvc.cmake +++ /dev/null @@ -1,6 +0,0 @@ -set( CMAKE_SYSTEM_NAME Windows ) -set( CMAKE_SYSTEM_PROCESSOR arm64 ) - -set( target arm64-pc-windows-msvc ) -set( CMAKE_C_COMPILER_TARGET ${target} ) -set( CMAKE_CXX_COMPILER_TARGET ${target} ) diff --git a/cmake/build-info.cmake b/cmake/build-info.cmake index c1a456e17..75c78222f 100644 --- a/cmake/build-info.cmake +++ b/cmake/build-info.cmake @@ -41,14 +41,20 @@ endif() if(MSVC) set(BUILD_COMPILER "${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}") - set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) + if (CMAKE_VS_PLATFORM_NAME) + set(BUILD_TARGET ${CMAKE_VS_PLATFORM_NAME}) + else() + set(BUILD_TARGET "${CMAKE_SYSTEM_NAME} ${CMAKE_SYSTEM_PROCESSOR}") + endif() else() execute_process( - COMMAND sh -c "\"$@\" --version | head -1" _ ${CMAKE_C_COMPILER} + COMMAND ${CMAKE_C_COMPILER} --version OUTPUT_VARIABLE OUT OUTPUT_STRIP_TRAILING_WHITESPACE ) + string(REGEX REPLACE " *\n.*" "" OUT "${OUT}") set(BUILD_COMPILER ${OUT}) + execute_process( COMMAND ${CMAKE_C_COMPILER} -dumpmachine OUTPUT_VARIABLE OUT diff --git a/cmake/x64-windows-llvm.cmake b/cmake/x64-windows-llvm.cmake index 0603d738f..77e791407 100644 --- a/cmake/x64-windows-llvm.cmake +++ b/cmake/x64-windows-llvm.cmake @@ -3,9 +3,3 @@ set( CMAKE_SYSTEM_PROCESSOR x86_64 ) set( CMAKE_C_COMPILER clang ) set( CMAKE_CXX_COMPILER clang++ ) - -set( arch_c_flags "-march=native" ) - -set( CMAKE_C_FLAGS_INIT "${arch_c_flags}" ) -set( CMAKE_CXX_FLAGS_INIT "${arch_c_flags}" ) - diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 43533fc86..f43a630c9 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -7,8 +7,8 @@ llama_add_compile_flags() # Build info header # -if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git") - set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../.git") +if(EXISTS "${PROJECT_SOURCE_DIR}/.git") + set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git") # Is git submodule if(NOT IS_DIRECTORY "${GIT_DIR}") @@ -18,34 +18,26 @@ if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/../.git") if (SLASH_POS EQUAL 0) set(GIT_DIR "${REAL_GIT_DIR}") else() - set(GIT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../${REAL_GIT_DIR}") + set(GIT_DIR "${PROJECT_SOURCE_DIR}/${REAL_GIT_DIR}") endif() endif() if(EXISTS "${GIT_DIR}/index") - set(GIT_INDEX "${GIT_DIR}/index") + # For build-info.cpp below + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${GIT_DIR}/index") else() message(WARNING "Git index not found in git repository.") - set(GIT_INDEX "") endif() else() message(WARNING "Git repository not found; to enable automatic generation of build info, make sure Git is installed and the project is a Git repository.") - set(GIT_INDEX "") endif() -# Add a custom command to rebuild build-info.cpp when .git/index changes -add_custom_command( - OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp" - COMMENT "Generating build details from Git" - COMMAND ${CMAKE_COMMAND} -DMSVC=${MSVC} -DCMAKE_C_COMPILER_VERSION=${CMAKE_C_COMPILER_VERSION} - -DCMAKE_C_COMPILER_ID=${CMAKE_C_COMPILER_ID} -DCMAKE_VS_PLATFORM_NAME=${CMAKE_VS_PLATFORM_NAME} - -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -P "${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info-gen-cpp.cmake" - WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/.." - DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp.in" ${GIT_INDEX} - VERBATIM -) +set(TEMPLATE_FILE "${CMAKE_CURRENT_SOURCE_DIR}/build-info.cpp.in") +set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/build-info.cpp") +configure_file(${TEMPLATE_FILE} ${OUTPUT_FILE}) + set(TARGET build_info) -add_library(${TARGET} OBJECT build-info.cpp) +add_library(${TARGET} OBJECT ${OUTPUT_FILE}) if (BUILD_SHARED_LIBS) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() @@ -56,21 +48,24 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-parser.cpp + chat-parser.h chat.cpp chat.h common.cpp common.h console.cpp console.h + json-partial.cpp + json-partial.h json-schema-to-grammar.cpp - json.hpp llguidance.cpp log.cpp log.h - minja/chat-template.hpp - minja/minja.hpp ngram-cache.cpp ngram-cache.h + regex-partial.cpp + regex-partial.h sampling.cpp sampling.h speculative.cpp @@ -117,8 +112,8 @@ if (LLAMA_LLGUIDANCE) ExternalProject_Add(llguidance_ext GIT_REPOSITORY https://github.com/guidance-ai/llguidance - # v0.7.10: - GIT_TAG 0309d2a6bf40abda35344a362edc71e06d5009f8 + # v0.7.20 (+ fix to build on GCC 15): + GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8 PREFIX ${CMAKE_BINARY_DIR}/llguidance SOURCE_DIR ${LLGUIDANCE_SRC} BUILD_IN_SOURCE TRUE @@ -139,6 +134,30 @@ if (LLAMA_LLGUIDANCE) set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS}) endif () -target_include_directories(${TARGET} PUBLIC .) +target_include_directories(${TARGET} PUBLIC . ../vendor) target_compile_features (${TARGET} PUBLIC cxx_std_17) target_link_libraries (${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) + + +# +# copy the license files +# + +# Check if running in GitHub Actions +if (DEFINED ENV{GITHUB_ACTIONS} AND "$ENV{GITHUB_ACTIONS}" STREQUAL "true") + message(STATUS "Running inside GitHub Actions - copying license files") + + # Copy all files from licenses/ to build/bin/ + file(GLOB LICENSE_FILES "${CMAKE_SOURCE_DIR}/licenses/*") + foreach(LICENSE_FILE ${LICENSE_FILES}) + get_filename_component(FILENAME ${LICENSE_FILE} NAME) + add_custom_command( + POST_BUILD + TARGET ${TARGET} + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${LICENSE_FILE}" + "$/${FILENAME}" + COMMENT "Copying ${FILENAME} to ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}") + message(STATUS "Copying ${LICENSE_FILE} to ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${FILENAME}") + endforeach() +endif() diff --git a/common/arg.cpp b/common/arg.cpp index 0b57f9da1..231de227a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1,10 +1,11 @@ -#include "gguf.h" // for reading GGUF splits #include "arg.h" +#include "chat.h" #include "common.h" +#include "gguf.h" // for reading GGUF splits +#include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" -#include "chat.h" // fix problem with std::min and std::max #if defined(_WIN32) @@ -15,6 +16,9 @@ #include #endif +#define JSON_ASSERT GGML_ASSERT +#include + #include #include #include @@ -34,10 +38,32 @@ #include #endif -#include "json-schema-to-grammar.h" - using json = nlohmann::ordered_json; +std::initializer_list mmproj_examples = { + LLAMA_EXAMPLE_MTMD, + LLAMA_EXAMPLE_SERVER, +}; + +static std::string read_file(const std::string & fname) { + std::ifstream file(fname); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + file.close(); + return content; +} + +static void write_file(const std::string & fname, const std::string & content) { + std::ofstream file(fname); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str())); + } + file << content; + file.close(); +} + common_arg & common_arg::set_examples(std::initializer_list examples) { this->examples = std::move(examples); return *this; @@ -157,6 +183,10 @@ struct common_hf_file_res { #ifdef LLAMA_USE_CURL +bool common_has_curl() { + return true; +} + #ifdef __linux__ #include #elif defined(_WIN32) @@ -189,11 +219,11 @@ struct curl_slist_ptr { #define CURL_MAX_RETRY 3 #define CURL_RETRY_DELAY_SECONDS 2 -static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) { +static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds, const char * method_name) { int remaining_attempts = max_attempts; while (remaining_attempts > 0) { - LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); + LOG_INF("%s: %s %s (attempt %d of %d)...\n", __func__ , method_name, url.c_str(), max_attempts - remaining_attempts + 1, max_attempts); CURLcode res = curl_easy_perform(curl); if (res == CURLE_OK) { @@ -204,6 +234,7 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay); remaining_attempts--; + if (remaining_attempts == 0) break; std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); } @@ -213,7 +244,56 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma } // download one single file from remote URL to local path -static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token) { +static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) { + // Check if the file already exists locally + auto file_exists = std::filesystem::exists(path); + + // If the file exists, check its JSON metadata companion file. + std::string metadata_path = path + ".json"; + nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead + std::string etag; + std::string last_modified; + + if (file_exists) { + if (offline) { + LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str()); + return true; // skip verification/downloading + } + // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). + std::ifstream metadata_in(metadata_path); + if (metadata_in.good()) { + try { + metadata_in >> metadata; + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); + if (metadata.contains("etag") && metadata.at("etag").is_string()) { + etag = metadata.at("etag"); + } + if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { + last_modified = metadata.at("lastModified"); + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) + } else { + if (offline) { + LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str()); + return false; + } + LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); + } + + // Send a HEAD request to retrieve the etag and last-modified headers + struct common_load_model_from_url_headers { + std::string etag; + std::string last_modified; + }; + + common_load_model_from_url_headers headers; + bool head_request_ok = false; + bool should_download = !file_exists; // by default, we should download if the file does not exist + // Initialize libcurl curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); curl_slist_ptr http_headers; @@ -222,8 +302,6 @@ static bool common_download_file_single(const std::string & url, const std::stri return false; } - bool force_download = false; - // Set the URL, allow to follow http redirection curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); @@ -242,97 +320,54 @@ static bool common_download_file_single(const std::string & url, const std::stri curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); #endif - // Check if the file already exists locally - auto file_exists = std::filesystem::exists(path); + typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); + auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { + common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; - nlohmann::json metadata; - std::string etag; - std::string last_modified; + static std::regex header_regex("([^:]+): (.*)\r\n"); + static std::regex etag_regex("ETag", std::regex_constants::icase); + static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). - std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str()); - if (metadata.contains("url") && metadata.at("url").is_string()) { - auto previous_url = metadata.at("url").get(); - if (previous_url != url) { - LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str()); - return false; - } - } - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - return false; + std::string header(buffer, n_items); + std::smatch match; + if (std::regex_match(header, match, header_regex)) { + const std::string & key = match[1]; + const std::string & value = match[2]; + if (std::regex_match(key, match, etag_regex)) { + headers->etag = value; + } else if (std::regex_match(key, match, last_modified_regex)) { + headers->last_modified = value; } } - } else { - LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); - } - - // Send a HEAD request to retrieve the etag and last-modified headers - struct common_load_model_from_url_headers { - std::string etag; - std::string last_modified; + return n_items; }; - common_load_model_from_url_headers headers; + curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress + curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); + curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - { - typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *); - auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t { - common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata; - - static std::regex header_regex("([^:]+): (.*)\r\n"); - static std::regex etag_regex("ETag", std::regex_constants::icase); - static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase); - - std::string header(buffer, n_items); - std::smatch match; - if (std::regex_match(header, match, header_regex)) { - const std::string & key = match[1]; - const std::string & value = match[2]; - if (std::regex_match(key, match, etag_regex)) { - headers->etag = value; - } else if (std::regex_match(key, match, last_modified_regex)) { - headers->last_modified = value; - } - } - return n_items; - }; - - curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress - curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast(header_callback)); - curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); - - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); - if (!was_perform_successful) { - return false; - } - - long http_code = 0; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); - if (http_code != 200) { - // HEAD not supported, we don't know if the file has changed - // force trigger downloading - force_download = true; - LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); - } + // we only allow retrying once for HEAD requests + // this is for the use case of using running offline (no internet), retrying can be annoying + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD"); + if (!was_perform_successful) { + head_request_ok = false; } - bool should_download = !file_exists || force_download; - if (!should_download) { + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code); + if (http_code == 200) { + head_request_ok = true; + } else { + LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code); + head_request_ok = false; + } + + // if head_request_ok is false, we don't have the etag or last-modified headers + // we leave should_download as-is, which is true if the file does not exist + if (head_request_ok) { + // check if ETag or Last-Modified headers are different + // if it is, we need to download the file again if (!etag.empty() && etag != headers.etag) { LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str()); should_download = true; @@ -341,6 +376,7 @@ static bool common_download_file_single(const std::string & url, const std::stri should_download = true; } } + if (should_download) { std::string path_temporary = path + ".downloadInProgress"; if (file_exists) { @@ -394,7 +430,7 @@ static bool common_download_file_single(const std::string & url, const std::stri // start the download LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__, llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str()); - bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS); + bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS, "GET"); if (!was_perform_successful) { return false; } @@ -415,13 +451,15 @@ static bool common_download_file_single(const std::string & url, const std::stri {"etag", headers.etag}, {"lastModified", headers.last_modified} }); - std::ofstream(metadata_path) << metadata.dump(4); - LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + write_file(metadata_path, metadata.dump(4)); + LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); if (rename(path_temporary.c_str(), path.c_str()) != 0) { LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } + } else { + LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); } return true; @@ -429,12 +467,12 @@ static bool common_download_file_single(const std::string & url, const std::stri // download multiple files from remote URLs to local paths // the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token) { +static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { // Prepare download in parallel std::vector> futures_download; for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token); + futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline); }, item)); } @@ -450,14 +488,15 @@ static bool common_download_file_multiple(const std::vector> common_remote_get_content(const std::string & url, const common_remote_params & params) { + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::vector res_buffer; + + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + auto data_vec = static_cast *>(data); + data_vec->insert(data_vec->end(), (char *)ptr, (char *)ptr + size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_buffer); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (params.timeout > 0) { + curl_easy_setopt(curl.get(), CURLOPT_TIMEOUT, params.timeout); + } + if (params.max_size > 0) { + curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); + } + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + for (const auto & header : params.headers) { + http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); + } + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + std::string error_msg = curl_easy_strerror(res); + throw std::runtime_error("error: cannot make GET request: " + error_msg); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + + return { res_code, std::move(res_buffer) }; +} + /** * Allow getting the HF file from the HF repo with tag (like ollama), for example: * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 @@ -533,7 +616,7 @@ static bool common_download_model( * * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. */ -static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token) { +static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) { auto parts = string_split(hf_repo_with_tag, ':'); std::string tag = parts.size() > 1 ? parts.back() : "latest"; std::string hf_repo = parts[0]; @@ -541,46 +624,53 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); } - // fetch model info from Hugging Face Hub API - curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); - curl_slist_ptr http_headers; - std::string res_str; + std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; - std::string model_endpoint = get_model_endpoint(); - - std::string url = model_endpoint + "v2/" + hf_repo + "/manifests/" + tag; - curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); - typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); - auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { - static_cast(data)->append((char * ) ptr, size * nmemb); - return size * nmemb; - }; - curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); - curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); -#if defined(_WIN32) - curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif + // headers + std::vector headers; + headers.push_back("Accept: application/json"); if (!bearer_token.empty()) { - std::string auth_header = "Authorization: Bearer " + bearer_token; - http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + headers.push_back("Authorization: Bearer " + bearer_token); } // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response - http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); - http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + // User-Agent header is already set in common_remote_get_content, no need to set it here - CURLcode res = curl_easy_perform(curl.get()); + // we use "=" to avoid clashing with other component, while still being allowed on windows + std::string cached_response_fname = "manifest=" + hf_repo + "=" + tag + ".json"; + string_replace_all(cached_response_fname, "/", "_"); + std::string cached_response_path = fs_get_cache_file(cached_response_fname); - if (res != CURLE_OK) { - throw std::runtime_error("error: cannot make GET request to HF API"); + // make the request + common_remote_params params; + params.headers = headers; + long res_code = 0; + std::string res_str; + bool use_cache = false; + if (!offline) { + try { + auto res = common_remote_get_content(url, params); + res_code = res.first; + res_str = std::string(res.second.data(), res.second.size()); + } catch (const std::exception & e) { + LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what()); + } } + if (res_code == 0) { + if (std::filesystem::exists(cached_response_path)) { + LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str()); + res_str = read_file(cached_response_path); + res_code = 200; + use_cache = true; + } else { + throw std::runtime_error( + offline ? "error: failed to get manifest (offline mode)" + : "error: failed to get manifest (check your internet connection)"); + } + } + std::string ggufFile; + std::string mmprojFile; - long res_code; - std::string ggufFile = ""; - std::string mmprojFile = ""; - curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); - if (res_code == 200) { + if (res_code == 200 || res_code == 304) { // extract ggufFile.rfilename in json, using regex { std::regex pattern("\"ggufFile\"[\\s\\S]*?\"rfilename\"\\s*:\\s*\"([^\"]+)\""); @@ -597,6 +687,10 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ mmprojFile = match[1].str(); } } + if (!use_cache) { + // if not using cached response, update the cache file + write_file(cached_response_path, res_str); + } } else if (res_code == 401) { throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); } else { @@ -613,51 +707,75 @@ static struct common_hf_file_res common_get_hf_file(const std::string & hf_repo_ #else -static bool common_download_file_single(const std::string &, const std::string &, const std::string &) { +bool common_has_curl() { + return false; +} + +static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from internet\n"); return false; } -static bool common_download_file_multiple(const std::vector> &, const std::string &) { +static bool common_download_file_multiple(const std::vector> &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return false; } static bool common_download_model( const common_params_model &, - const std::string &) { + const std::string &, + bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return false; } -static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &) { +static struct common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { LOG_ERR("error: built without CURL, cannot download model from the internet\n"); return {}; } +std::pair> common_remote_get_content(const std::string & url, const common_remote_params &) { + if (!url.empty()) { + throw std::runtime_error("error: built without CURL, cannot download model from the internet"); + } + + return {}; +} + #endif // LLAMA_USE_CURL // // utils // -static void common_params_handle_model( +struct handle_model_result { + bool found_mmproj = false; + common_params_model mmproj; +}; + +static handle_model_result common_params_handle_model( struct common_params_model & model, const std::string & bearer_token, const std::string & model_path_default, - bool is_mmproj = false) { // TODO: move is_mmproj to an enum when we have more files? + bool offline) { + handle_model_result result; // handle pre-fill default model path and url based on hf_repo and hf_file { if (!model.hf_repo.empty()) { // short-hand to avoid specifying --hf-file -> default it to --model if (model.hf_file.empty()) { if (model.path.empty()) { - auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token); + auto auto_detected = common_get_hf_file(model.hf_repo, bearer_token, offline); if (auto_detected.repo.empty() || auto_detected.ggufFile.empty()) { exit(1); // built without CURL, error message already printed } model.hf_repo = auto_detected.repo; - model.hf_file = is_mmproj ? auto_detected.mmprojFile : auto_detected.ggufFile; + model.hf_file = auto_detected.ggufFile; + if (!auto_detected.mmprojFile.empty()) { + result.found_mmproj = true; + result.mmproj.hf_repo = model.hf_repo; + result.mmproj.hf_file = auto_detected.mmprojFile; + } } else { model.hf_file = model.path; } @@ -688,12 +806,14 @@ static void common_params_handle_model( // then, download it if needed if (!model.url.empty()) { - bool ok = common_download_model(model, bearer_token); + bool ok = common_download_model(model, bearer_token, offline); if (!ok) { LOG_ERR("error: failed to download model from %s\n", model.url.c_str()); exit(1); } } + + return result; } const std::vector kv_cache_types = { @@ -827,16 +947,25 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH); - common_params_handle_model(params.speculative.model, params.hf_token, ""); - common_params_handle_model(params.vocoder.model, params.hf_token, ""); - - // allow --mmproj to be set from -hf - // assuming that mmproj is always in the same repo as text model - if (!params.model.hf_repo.empty() && ctx_arg.ex == LLAMA_EXAMPLE_LLAVA) { - params.mmproj.hf_repo = params.model.hf_repo; + // handle model and download + { + auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH, params.offline); + if (params.no_mmproj) { + params.mmproj = {}; + } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { + // optionally, handle mmproj model when -hf is specified + params.mmproj = res.mmproj; + } + // only download mmproj if the current example is using it + for (auto & ex : mmproj_examples) { + if (ctx_arg.ex == ex) { + common_params_handle_model(params.mmproj, params.hf_token, "", params.offline); + break; + } + } + common_params_handle_model(params.speculative.model, params.hf_token, "", params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, "", params.offline); } - common_params_handle_model(params.mmproj, params.hf_token, "", true); if (params.escape) { string_process_escapes(params.prompt); @@ -859,10 +988,6 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context params.tensor_buft_overrides.push_back({nullptr, nullptr}); } - if (params.reranking && params.embedding) { - throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); - } - if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) { throw std::runtime_error(string_format( "error: the supplied chat template is not supported: %s%s\n", @@ -968,7 +1093,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) { "llama-embedding", "llama-eval-callback", "llama-export-lora", - "llama-gbnf-validator", "llama-gen-docs", "llama-gguf", "llama-gguf-hash", @@ -976,20 +1100,18 @@ static void common_params_print_completion(common_params_context & ctx_arg) { "llama-gritlm", "llama-imatrix", "llama-infill", - "llama-llava-cli", + "llama-mtmd-cli", "llama-llava-clip-quantize-cli", "llama-lookahead", "llama-lookup", "llama-lookup-create", "llama-lookup-merge", "llama-lookup-stats", - "llama-minicpmv-cli", "llama-parallel", "llama-passkey", "llama-perplexity", "llama-q8dot", "llama-quantize", - "llama-quantize-stats", "llama-qwen2vl-cli", "llama-retrieval", "llama-run", @@ -1078,6 +1200,9 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e fprintf(stderr, "%s\n", ex.what()); ctx_arg.params = params_org; return false; + } catch (std::exception & ex) { + fprintf(stderr, "%s\n", ex.what()); + exit(1); // for other exceptions, we exit with status code 1 } return true; @@ -1169,7 +1294,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.use_color = true; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP})); add_opt(common_arg( {"-t", "--threads"}, "N", string_format("number of threads to use during generation (default: %d)", params.cpuparams.n_threads), @@ -1219,9 +1344,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"--prio"}, "N", - string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority), + string_format("set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: %d)\n", params.cpuparams.priority), [](common_params & params, int prio) { - if (prio < 0 || prio > 3) { + if (prio < GGML_SCHED_PRIO_LOW || prio > GGML_SCHED_PRIO_REALTIME) { throw std::invalid_argument("invalid value"); } params.cpuparams.priority = (enum ggml_sched_priority) prio; @@ -1302,7 +1427,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"-n", "--predict", "--n-predict"}, "N", string_format( - ex == LLAMA_EXAMPLE_MAIN || ex == LLAMA_EXAMPLE_INFILL + ex == LLAMA_EXAMPLE_MAIN ? "number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)" : "number of tokens to predict (default: %d, -1 = infinity)", params.n_predict), @@ -1331,6 +1456,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.n_keep = value; } )); + add_opt(common_arg( + {"--swa-full"}, + string_format("use full-size SWA cache (default: %s)\n" + "[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"), + [](common_params & params) { + params.swa_full = true; + } + ).set_env("LLAMA_ARG_SWA_FULL")); add_opt(common_arg( {"--no-context-shift"}, string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"), @@ -1378,13 +1511,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-f", "--file"}, "FNAME", "a file containing the prompt (default: none)", [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } + params.prompt = read_file(value); // store the external file name in params params.prompt_file = value; - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); if (!params.prompt.empty() && params.prompt.back() == '\n') { params.prompt.pop_back(); } @@ -1394,11 +1523,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-sysf", "--system-prompt-file"}, "FNAME", "a file containing the system prompt (default: none)", [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.system_prompt)); + params.system_prompt = read_file(value); if (!params.system_prompt.empty() && params.system_prompt.back() == '\n') { params.system_prompt.pop_back(); } @@ -1549,7 +1674,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.input_prefix = value; params.enable_chat_template = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); + ).set_examples({LLAMA_EXAMPLE_MAIN})); add_opt(common_arg( {"--in-suffix"}, "STRING", "string to suffix after user inputs with (default: empty)", @@ -1557,14 +1682,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.input_suffix = value; params.enable_chat_template = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); + ).set_examples({LLAMA_EXAMPLE_MAIN})); add_opt(common_arg( {"--no-warmup"}, "skip warming up the model with an empty run", [](common_params & params) { params.warmup = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -1574,7 +1699,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.spm_infill = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_INFILL})); + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--samplers"}, "SAMPLERS", string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()), @@ -1823,15 +1948,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--grammar-file"}, "FNAME", "file to read grammar from", [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(params.sampling.grammar) - ); + params.sampling.grammar = read_file(value); } ).set_sparam()); add_opt(common_arg( @@ -1841,6 +1958,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(value)); } ).set_sparam()); + add_opt(common_arg( + {"-jf", "--json-schema-file"}, "FILE", + "File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::string schema; + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(schema) + ); + params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); + } + ).set_sparam()); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", @@ -1942,13 +2076,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.grp_attn_w = value; } ).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN})); - add_opt(common_arg( - {"-dkvc", "--dump-kv-cache"}, - "verbose print of the KV cache", - [](common_params & params) { - params.dump_kv_cache = true; - } - )); add_opt(common_arg( {"-nkvo", "--no-kv-offload"}, "disable KV offload", @@ -1982,13 +2109,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.cache_type_v = kv_cache_type_from_str(value); } ).set_env("LLAMA_ARG_CACHE_TYPE_V")); - add_opt(common_arg( - {"--perplexity", "--all-logits"}, - string_format("return logits for all tokens in the batch (default: %s)", params.logits_all ? "true" : "false"), - [](common_params & params) { - params.logits_all = true; - } - ).set_examples({LLAMA_EXAMPLE_PERPLEXITY})); add_opt(common_arg( {"--hellaswag"}, "compute HellaSwag score over random tasks from datafile supplied with -f", @@ -2096,25 +2216,40 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONT_BATCHING")); add_opt(common_arg( {"--mmproj"}, "FILE", - "path to a multimodal projector file for LLaVA. see examples/llava/README.md", + "path to a multimodal projector file. see tools/mtmd/README.md\n" + "note: if -hf is used, this argument can be omitted", [](common_params & params, const std::string & value) { params.mmproj.path = value; } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ")); add_opt(common_arg( {"--mmproj-url"}, "URL", - "URL to a multimodal projector file for LLaVA. see examples/llava/README.md", + "URL to a multimodal projector file. see tools/mtmd/README.md", [](common_params & params, const std::string & value) { params.mmproj.url = value; } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_URL")); add_opt(common_arg( - {"--image"}, "FILE", - "path to an image file. use with multimodal models. Specify multiple times for batching", + {"--no-mmproj"}, + "explicitly disable multimodal projector, useful when using -hf", + [](common_params & params) { + params.no_mmproj = true; + } + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ")); + add_opt(common_arg( + {"--no-mmproj-offload"}, + "do not offload multimodal projector to GPU", + [](common_params & params) { + params.mmproj_use_gpu = false; + } + ).set_examples(mmproj_examples).set_env("LLAMA_ARG_NO_MMPROJ_OFFLOAD")); + add_opt(common_arg( + {"--image", "--audio"}, "FILE", + "path to an image or audio file. use with multimodal models, can be repeated if you have multiple files\n", [](common_params & params, const std::string & value) { params.image.emplace_back(value); } - ).set_examples({LLAMA_EXAMPLE_LLAVA})); + ).set_examples({LLAMA_EXAMPLE_MTMD})); if (llama_supports_rpc()) { add_opt(common_arg( {"--rpc"}, "SERVERS", @@ -2314,6 +2449,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } )); + add_opt(common_arg( + {"--no-op-offload"}, + string_format("disable offloading host tensor operations to device (default: %s)", params.no_op_offload ? "true" : "false"), + [](common_params & params) { + params.no_op_offload = true; + } + )); add_opt(common_arg( {"--lora"}, "FNAME", "path to LoRA adapter (can be repeated to use multiple adapters)", @@ -2382,6 +2524,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex add_opt(common_arg( {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" + "mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n" "example: unsloth/phi-4-GGUF:q4_k_m\n" "(default: unused)", [](common_params & params, const std::string & value) { @@ -2454,7 +2597,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.n_junk = value; } - ).set_examples({LLAMA_EXAMPLE_PASSKEY})); + ).set_examples({LLAMA_EXAMPLE_PASSKEY, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( {"--pos"}, "N", string_format("position of the passkey in the junk text (default: %d)", params.i_pos), @@ -2504,13 +2647,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.i_chunk = value; } ).set_examples({LLAMA_EXAMPLE_IMATRIX})); + add_opt(common_arg( + {"--parse-special"}, + string_format("prase special tokens (chat, tool, etc) (default: %s)", params.parse_special ? "true" : "false"), + [](common_params & params) { + params.parse_special = true; + } + ).set_examples({LLAMA_EXAMPLE_IMATRIX})); add_opt(common_arg( {"-pps"}, string_format("is the prompt shared across parallel sequences (default: %s)", params.is_pp_shared ? "true" : "false"), [](common_params & params) { params.is_pp_shared = true; } - ).set_examples({LLAMA_EXAMPLE_BENCH})); + ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( {"-npp"}, "n0,n1,...", "number of prompt tokens", @@ -2593,9 +2743,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); add_opt(common_arg( {"--reranking", "--rerank"}, - string_format("enable reranking endpoint on server (default: %s)", params.reranking ? "enabled" : "disabled"), + string_format("enable reranking endpoint on server (default: %s)", "disabled"), [](common_params & params) { - params.reranking = true; + params.embedding = true; + params.pooling_type = LLAMA_POOLING_TYPE_RANK; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(common_arg( @@ -2653,7 +2804,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP")); add_opt(common_arg( {"--cache-reuse"}, "N", - string_format("min chunk size to attempt reusing from the cache via KV shifting (default: %d)", params.n_cache_reuse), + string_format( + "min chunk size to attempt reusing from the cache via KV shifting (default: %d)\n" + "[(card)](https://ggml.ai/f0.png)", params.n_cache_reuse + ), [](common_params & params, int value) { params.n_cache_reuse = value; } @@ -2706,15 +2860,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", - "reasoning format (default: deepseek; allowed values: deepseek, none)\n" - "controls whether thought tags are extracted from the response, and in which format they're returned. 'none' leaves thoughts unparsed in `message.content`, 'deepseek' puts them in `message.reasoning_content` (for DeepSeek R1 & Command R7B only).\n" - "only supported for non-streamed responses", + "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" + "- none: leaves thoughts unparsed in `message.content`\n" + "- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n" + "(default: deepseek)", [](common_params & params, const std::string & value) { /**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; } + else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; } else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; } - else { std::invalid_argument("invalid value"); } + else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK")); + add_opt(common_arg( + {"--reasoning-budget"}, "N", + "controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)", + [](common_params & params, int value) { + if (value != 0 && value != -1) { throw std::invalid_argument("invalid value"); } + params.reasoning_budget = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK_BUDGET")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( @@ -2726,7 +2890,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.chat_template = value; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); add_opt(common_arg( {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", string_format( @@ -2736,16 +2900,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() ), [](common_params & params, const std::string & value) { - std::ifstream file(value); - if (!file) { - throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); - } - std::copy( - std::istreambuf_iterator(file), - std::istreambuf_iterator(), - std::back_inserter(params.chat_template)); + params.chat_template = read_file(value); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); + add_opt(common_arg( + {"--no-prefill-assistant"}, + string_format( + "whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)\n" + "when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled\n" + ), + [](common_params & params) { + params.prefill_assistant = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_PREFILL_ASSISTANT")); add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), @@ -2766,7 +2933,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.simple_io = true; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); + ).set_examples({LLAMA_EXAMPLE_MAIN})); add_opt(common_arg( {"--positive-file"}, "FNAME", string_format("positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str()), @@ -2810,7 +2977,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { /**/ if (value == "jsonl") { params.batched_bench_output_jsonl = true; } else if (value == "md") { params.batched_bench_output_jsonl = false; } - else { std::invalid_argument("invalid value"); } + else { throw std::invalid_argument("invalid value"); } } ).set_examples({LLAMA_EXAMPLE_BENCH})); add_opt(common_arg( @@ -2842,6 +3009,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex common_log_set_verbosity_thold(INT_MAX); } )); + add_opt(common_arg( + {"--offline"}, + "Offline mode: forces use of cache, prevents network access", + [](common_params & params) { + params.offline = true; + } + ).set_env("LLAMA_OFFLINE")); add_opt(common_arg( {"-lv", "--verbosity", "--log-verbosity"}, "N", "Set the verbosity threshold. Messages with a higher verbosity will be ignored.", diff --git a/common/arg.h b/common/arg.h index 49ab8667b..70bea100f 100644 --- a/common/arg.h +++ b/common/arg.h @@ -78,3 +78,12 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); +bool common_has_curl(); + +struct common_remote_params { + std::vector headers; + long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout + long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB +}; +// get remote file content, returns +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); diff --git a/common/build-info.cpp.in b/common/build-info.cpp.in index 0b945aa68..aee9d7eaf 100644 --- a/common/build-info.cpp.in +++ b/common/build-info.cpp.in @@ -1,4 +1,4 @@ -int LLAMA_BUILD_NUMBER = @BUILD_NUMBER@; -char const *LLAMA_COMMIT = "@BUILD_COMMIT@"; +int LLAMA_BUILD_NUMBER = @LLAMA_BUILD_NUMBER@; +char const *LLAMA_COMMIT = "@LLAMA_BUILD_COMMIT@"; char const *LLAMA_COMPILER = "@BUILD_COMPILER@"; char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@"; diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp new file mode 100644 index 000000000..18a30e49a --- /dev/null +++ b/common/chat-parser.cpp @@ -0,0 +1,385 @@ +#include "chat-parser.h" +#include "common.h" +#include "log.h" +#include "regex-partial.h" + +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) + : input_(input), is_partial_(is_partial), syntax_(syntax) +{ + result_.role = "assistant"; + + while (true) { + std::string id = std::to_string(std::rand()); + if (input.find(id) == std::string::npos) { + healing_marker_ = id; + break; + } + } +} + +std::string common_chat_msg_parser::str(const common_string_range & rng) const { + GGML_ASSERT(rng.begin <= rng.end); + return input_.substr(rng.begin, rng.end - rng.begin); +} + +void common_chat_msg_parser::add_content(const std::string &content) { + result_.content += content; +} + +void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) { + result_.reasoning_content += reasoning_content; +} + +bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) { + if (name.empty()) { + return false; + } + + common_chat_tool_call tool_call; + tool_call.name = name; + tool_call.arguments = arguments; + tool_call.id = id; + + // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str()); + result_.tool_calls.emplace_back(tool_call); + + return true; +} +bool common_chat_msg_parser::add_tool_call(const json & tool_call) { + std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; + std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; + std::string arguments = tool_call.contains("arguments") ? tool_call.at("arguments") : ""; + return add_tool_call(name, id, arguments); +} + +bool common_chat_msg_parser::add_tool_calls(const json & arr) { + for (const auto & item : arr) { + if (!add_tool_call(item)) { + return false; + } + } + return true; +} +void common_chat_msg_parser::finish() { + if (!is_partial_ && pos_ != input_.size()) { + throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); + } +} + +bool common_chat_msg_parser::consume_spaces() { + const auto length = input_.size(); + auto consumed = false; + while (pos_ < length && std::isspace(input_[pos_])) { + ++pos_; + consumed = true; + } + return consumed; +} + +bool common_chat_msg_parser::try_consume_literal(const std::string & literal) { + auto pos = pos_; + for (auto i = 0u; i < literal.size(); ++i) { + if (pos >= input_.size()) { + return false; + } + if (input_[pos] != literal[i]) { + return false; + } + ++pos; + } + pos_ = pos; + return true; +} + +std::optional common_chat_msg_parser::try_find_literal(const std::string & literal) { + auto idx = input_.find(literal, pos_); + if (idx != std::string::npos) { + find_regex_result res; + res.prelude = input_.substr(pos_, idx - pos_); + auto end = idx + literal.size(); + res.groups.emplace_back(common_string_range{idx, end}); + move_to(end); + return res; + } + if (is_partial_) { + idx = string_find_partial_stop(input_, literal); + if (idx != std::string::npos && idx >= pos_) { + find_regex_result res; + res.prelude = input_.substr(pos_, idx - pos_); + auto end = input_.size(); + res.groups.emplace_back(common_string_range{idx, end}); + move_to(end); + return res; + } + } + return std::nullopt; +} + +void common_chat_msg_parser::consume_literal(const std::string & literal) { + if (!try_consume_literal(literal)) { + throw common_chat_msg_partial_exception(literal); + } +} + +bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) { + auto handle_reasoning = [&](const std::string & reasoning, bool closed) { + auto stripped_reasoning = string_strip(reasoning); + if (stripped_reasoning.empty()) { + return; + } + if (syntax_.reasoning_in_content) { + add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : start_think); + add_content(stripped_reasoning); + if (closed) { + add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : end_think); + } + } else { + add_reasoning_content(stripped_reasoning); + } + }; + if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) { + if (syntax_.thinking_forced_open || try_consume_literal(start_think)) { + if (auto res = try_find_literal(end_think)) { + handle_reasoning(res->prelude, /* closed */ true); + consume_spaces(); + return true; + } + auto rest = consume_rest(); + if (!rest.empty()) { + handle_reasoning(rest, /* closed */ !is_partial()); + } + // Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877) + // if (!syntax_.thinking_forced_open) { + // throw common_chat_msg_partial_exception(end_think); + // } + return true; + } + } + return false; +} + +std::string common_chat_msg_parser::consume_rest() { + auto rest = input_.substr(pos_); + pos_ = input_.size(); + return rest; +} + +// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback. +std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) { + auto m = regex.search(input_, from == std::string::npos ? pos_ : from); + if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { + return std::nullopt; + } + auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); + pos_ = m.groups[0].end; + + if (add_prelude_to_content) { + add_content(prelude); + } + if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { + if (is_partial()) { + throw common_chat_msg_partial_exception(regex.str()); + } + return std::nullopt; + } + return find_regex_result{prelude, m.groups}; +} + +common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) { + if (auto result = try_consume_regex(regex)) { + return *result; + } + throw common_chat_msg_partial_exception(regex.str()); +} + +std::optional common_chat_msg_parser::try_consume_regex(const common_regex & regex) { + auto m = regex.search(input_, pos_); + if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { + return std::nullopt; + } + if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { + if (is_partial()) { + throw common_chat_msg_partial_exception(regex.str()); + } + return std::nullopt; + } + if (m.groups[0].begin != pos_) { + // Didn't match at the current position. + return std::nullopt; + } + pos_ = m.groups[0].end; + + return find_regex_result { + /* .prelude = */ "", + m.groups, + }; +} + +std::optional common_chat_msg_parser::try_consume_json() { + auto it = input_.cbegin() + pos_; + const auto end = input_.cend(); + common_json result; + if (!common_json_parse(it, end, healing_marker_, result)) { + return std::nullopt; + } + pos_ = std::distance(input_.cbegin(), it); + if (result.healing_marker.marker.empty()) { + // No healing marker, just return the parsed json + return result; + } + if (!is_partial()) { + throw common_chat_msg_partial_exception("JSON"); + } + return result; +} + +common_json common_chat_msg_parser::consume_json() { + if (auto result = try_consume_json()) { + return *result; + } + throw common_chat_msg_partial_exception("JSON"); +} + +common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args( + const std::vector> & args_paths, + const std::vector> & content_paths +) { + if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) { + return *result; + } + throw common_chat_msg_partial_exception("JSON"); +} + +std::optional common_chat_msg_parser::try_consume_json_with_dumped_args( + const std::vector> & args_paths, + const std::vector> & content_paths +) { + auto partial = try_consume_json(); + if (!partial) { + return std::nullopt; + } + auto is_arguments_path = [&](const std::vector & path) { + return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end(); + }; + auto is_content_path = [&](const std::vector & path) { + return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end(); + }; + + if (partial->healing_marker.marker.empty()) { + if (args_paths.empty()) { + // No arguments to dump, and JSON was parsed fully. + return consume_json_result { + partial->json, + /* .is_partial = */ false, + }; + } + if (is_arguments_path({})) { + // Entire JSON is the arguments and was parsed fully. + return consume_json_result { + partial->json.dump(), + /* .is_partial = */ false, + }; + } + } + + LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); + + auto found_healing_marker = false; + std::vector path; + std::function remove_unsupported_healings_and_dump_args = [&](const json & j) -> json { + if (is_arguments_path(path)) { + auto arguments = j.dump(); + if (is_partial() && !partial->healing_marker.marker.empty()) { + auto idx = arguments.find(partial->healing_marker.json_dump_marker); + if (idx != std::string::npos) { + arguments.resize(idx); + found_healing_marker = true; + } + if (arguments == "\"") { + // This happens because of completing `:"$magic` after `"arguments"` + arguments = ""; + } + } + return arguments; + } + if (is_content_path(path)) { + if (!j.is_string()) { + throw std::runtime_error("Content path must be a string"); + } + std::string str = j; + auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string + if (idx != std::string::npos) { + str.resize(idx); + found_healing_marker = true; + } + return str; + } + if (j.is_object()) { + auto obj = json::object(); + for (const auto & p : j.items()) { + const auto & key = p.key(); + const auto & value = p.value(); + const std::string key_str = key; // NOLINT + auto idx = key_str.find(healing_marker_); + if (idx != std::string::npos) { + found_healing_marker = true; + break; + } + path.push_back(key_str); + if (value.is_string()) { + const std::string value_str = value; + if (value_str.find(healing_marker_) != std::string::npos) { + found_healing_marker = true; + if (is_content_path(path)) { + if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) { + // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair. + obj[key] = remove_unsupported_healings_and_dump_args(value); + } + } + break; + } + obj[key] = value; + } else { + obj[key] = remove_unsupported_healings_and_dump_args(value); + } + path.pop_back(); + } + return obj; + } + if (j.is_array()) { + auto arr = json::array(); + for (const auto & value : j) { + if (value.is_string()) { + std::string str = value; + auto idx = str.find(healing_marker_); + if (idx != std::string::npos) { + // Don't heal array values that aren't in the arguments. + found_healing_marker = true; + break; + } + } + arr.push_back(remove_unsupported_healings_and_dump_args(value)); + } + return arr; + } + return j; + }; + + auto cleaned = remove_unsupported_healings_and_dump_args(partial->json); + LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); + return consume_json_result { + cleaned, + /* .is_partial = */ found_healing_marker, + }; +} + +void common_chat_msg_parser::clear_tools() { + result_.tool_calls.clear(); +} diff --git a/common/chat-parser.h b/common/chat-parser.h new file mode 100644 index 000000000..0e64c341a --- /dev/null +++ b/common/chat-parser.h @@ -0,0 +1,120 @@ +#pragma once + +#include "chat.h" +#include "json-partial.h" +#include "regex-partial.h" + +#include + +#include +#include +#include + +class common_chat_msg_partial_exception : public std::runtime_error { + public: + common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {} +}; + +class common_chat_msg_parser { + std::string input_; + bool is_partial_; + common_chat_syntax syntax_; + std::string healing_marker_; + + size_t pos_ = 0; + common_chat_msg result_; + + public: + common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax); + const std::string & input() const { return input_; } + size_t pos() const { return pos_; } + const std::string & healing_marker() const { return healing_marker_; } + const bool & is_partial() const { return is_partial_; } + const common_chat_msg & result() const { return result_; } + const common_chat_syntax & syntax() const { return syntax_; } + + void move_to(size_t pos) { + if (pos > input_.size()) { + throw std::runtime_error("Invalid position!"); + } + pos_ = pos; + } + void move_back(size_t n) { + if (pos_ < n) { + throw std::runtime_error("Can't move back that far!"); + } + pos_ -= n; + } + + // Get the substring of the input at the given range + std::string str(const common_string_range & rng) const; + + // Appends to the result.content field + void add_content(const std::string & content); + + // Appends to the result.reasoning_content field + void add_reasoning_content(const std::string & reasoning_content); + + // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything. + bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); + + // Adds a tool call using the "name", "id" and "arguments" fields of the json object + bool add_tool_call(const nlohmann::ordered_json & tool_call); + + // Adds an array of tool calls using their "name", "id" and "arguments" fields. + bool add_tool_calls(const nlohmann::ordered_json & arr); + + void finish(); + + bool consume_spaces(); + + void consume_literal(const std::string & literal); + + bool try_parse_reasoning(const std::string & start_think, const std::string & end_think); + + std::string consume_rest(); + + struct find_regex_result { + std::string prelude; + std::vector groups; + }; + + std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); + + bool try_consume_literal(const std::string & literal); + + std::optional try_find_literal(const std::string & literal); + + find_regex_result consume_regex(const common_regex & regex); + + std::optional try_consume_regex(const common_regex & regex); + + std::optional try_consume_json(); + common_json consume_json(); + + struct consume_json_result { + nlohmann::ordered_json value; + bool is_partial; + }; + + /* + Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings. + + By default, object keys can't be truncated, nor can string values (their corresponding key is removed, + e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}` + + But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings + - with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}` + - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}` + */ + consume_json_result consume_json_with_dumped_args( + const std::vector> & args_paths = {}, + const std::vector> & content_paths = {} + ); + std::optional try_consume_json_with_dumped_args( + const std::vector> & args_paths = {}, + const std::vector> & content_paths = {} + ); + + void clear_tools(); +}; diff --git a/common/chat.cpp b/common/chat.cpp index bbc5f087c..7d9aaeb12 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,10 +1,125 @@ #include "chat.h" +#include "chat-parser.h" +#include "common.h" +#include "json-partial.h" #include "json-schema-to-grammar.h" #include "log.h" -#include "minja/chat-template.hpp" -#include "minja/minja.hpp" +#include "regex-partial.h" +#include +#include + +#include +#include +#include #include +#include +#include +#include + +static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); + std::ostringstream ss; + ss << std::put_time(&local_time, format.c_str()); + auto res = ss.str(); + return res; +} + +static std::string string_diff(const std::string & last, const std::string & current) { + if (last.empty()) { + return current; + } + if (!string_starts_with(current, last)) { + if (string_starts_with(last, current)) { + // This happens if the last generation ended on a partial stop word (not erased), + // and the current ended on a stop word (erased). + return ""; + } + throw std::runtime_error("Invalid diff: '" + last + "' not found at start of '" + current + "'"); + } + return current.substr(last.size()); +} + +static bool has_content_or_tool_calls(const common_chat_msg & msg) { + return !msg.content.empty() || !msg.tool_calls.empty(); +} + +template <> +json common_chat_msg::to_json_oaicompat() const +{ + json message { + {"role", "assistant"}, + }; + if (!reasoning_content.empty()) { + message["reasoning_content"] = reasoning_content; + } + if (content.empty() && !tool_calls.empty()) { + message["content"] = json(); + } else { + message["content"] = content; + } + if (!tool_calls.empty()) { + auto arr = json::array(); + for (const auto & tc : tool_calls) { + arr.push_back({ + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }}, + {"id", tc.id}, + // // Some templates generate and require an id (sometimes in a very specific format, e.g. Mistral Nemo). + // // We only generate a random id for the ones that don't generate one by themselves + // // (they also won't get to see it as their template likely doesn't use it, so it's all for the client) + // {"id", tc.id.empty() ? gen_tool_call_id() : tc.id}, + }); + } + message["tool_calls"] = arr; + } + return message; +} + +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg) { + std::vector diffs; + if (previous_msg.reasoning_content != new_msg.reasoning_content) { + auto & diff = diffs.emplace_back(); + diff.reasoning_content_delta = string_diff(previous_msg.reasoning_content, new_msg.reasoning_content); + } + if (previous_msg.content != new_msg.content) { + auto & diff = diffs.emplace_back(); + diff.content_delta = string_diff(previous_msg.content, new_msg.content); + } + + if (new_msg.tool_calls.size() < previous_msg.tool_calls.size()) { + throw std::runtime_error("Invalid diff: now finding less tool calls!"); + } + + if (!previous_msg.tool_calls.empty()) { + auto idx = previous_msg.tool_calls.size() - 1; + const auto & pref = previous_msg.tool_calls[idx]; + const auto & newf = new_msg.tool_calls[idx]; + if (pref.name != newf.name) { + throw std::runtime_error("Invalid diff: tool call mismatch!"); + } + auto args_diff = string_diff(pref.arguments, newf.arguments); + if (!args_diff.empty() || pref.id != newf.id) { + auto & diff = diffs.emplace_back(); + diff.tool_call_index = idx; + if (pref.id != newf.id) { + diff.tool_call_delta.id = newf.id; + diff.tool_call_delta.name = newf.name; + } + diff.tool_call_delta.arguments = args_diff; + } + } + for (size_t idx = previous_msg.tool_calls.size(); idx < new_msg.tool_calls.size(); ++idx) { + auto & diff = diffs.emplace_back(); + diff.tool_call_index = idx; + diff.tool_call_delta = new_msg.tool_calls[idx]; + } + return diffs; +} typedef minja::chat_template common_chat_template; @@ -23,7 +138,8 @@ struct templates_params { bool stream; std::string grammar; bool add_generation_prompt = true; - bool extract_reasoning = true; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); }; common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { @@ -125,7 +241,9 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msgs.push_back(msg); } } catch (const std::exception & e) { - throw std::runtime_error("Failed to parse messages: " + std::string(e.what()) + "; messages = " + messages.dump(2)); + // @ngxson : disable otherwise it's bloating the API response + // printf("%s\n", std::string("; messages = ") + messages.dump(2)); + throw std::runtime_error("Failed to parse messages: " + std::string(e.what())); } return msgs; @@ -265,6 +383,32 @@ json common_chat_tools_to_json_oaicompat(const std::vector & t return result; } +template <> json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { + json delta = json::object(); + if (!diff.reasoning_content_delta.empty()) { + delta["reasoning_content"] = diff.reasoning_content_delta; + } + if (!diff.content_delta.empty()) { + delta["content"] = diff.content_delta; + } + if (diff.tool_call_index != std::string::npos) { + json tool_call; + tool_call["index"] = diff.tool_call_index; + if (!diff.tool_call_delta.id.empty()) { + tool_call["id"] = diff.tool_call_delta.id; + tool_call["type"] = "function"; + } + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + function["arguments"] = diff.tool_call_delta.arguments; + tool_call["function"] = function; + delta["tool_calls"] = json::array({tool_call}); + } + return delta; +} + bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { @@ -432,7 +576,7 @@ common_chat_templates_ptr common_chat_templates_init( return tmpls; } -std::string common_chat_format_name(common_chat_format format) { +const char * common_chat_format_name(common_chat_format format) { switch (format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; @@ -440,182 +584,128 @@ std::string common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: return "DeepSeek R1 (extract reasoning)"; case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; - case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: return "Hermes 2 Pro (extract reasoning)"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; - case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: return "Command R7B (extract reasoning)"; default: throw std::runtime_error("Unknown chat format"); } } -static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { - // // https://json.nlohmann.me/features/parsing/sax_interface/ - struct json_error_locator : public nlohmann::json_sax { - std::size_t position; - bool found_error; +const char * common_reasoning_format_name(common_reasoning_format format) { + switch (format) { + case COMMON_REASONING_FORMAT_NONE: return "none"; + case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; + case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + default: + throw std::runtime_error("Unknown reasoning format"); + } +} - json_error_locator() : position(0), found_error(false) {} - - bool parse_error(std::size_t position, const std::string &, const json::exception &) override { // NOLINT - this->position = position - 1; - this->found_error = true; - return false; +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json {{"code", code + builder.healing_marker()}}).dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); } - bool null() override { return true; } // NOLINT - bool boolean(bool) override { return true; } // NOLINT - bool number_integer(number_integer_t) override { return true; } // NOLINT - bool number_unsigned(number_unsigned_t) override { return true; } // NOLINT - bool number_float(number_float_t, const string_t &) override { return true; } // NOLINT - bool string(string_t &) override { return true; } // NOLINT - bool binary(binary_t &) override { return true; } // NOLINT - bool start_object(std::size_t) override { return true; } // NOLINT - bool key(string_t &) override { return true; } // NOLINT - bool end_object() override { return true; } - bool start_array(std::size_t) override { return true; } // NOLINT - bool end_array() override { return true; } - }; - json_error_locator err_loc; - json::sax_parse(it, end, &err_loc); - - std::string::const_iterator temptative_end; - if (err_loc.found_error) { - temptative_end = it + err_loc.position; } else { - temptative_end = end; - } - std::string json_sub {it, temptative_end}; - try { - out = json::parse(json_sub); - it = temptative_end; - return true; - } catch (const std::exception &) { - return false; - } -} - -static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { - auto expected_it = expected.begin(); - auto tmp_it = it; - while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { - ++tmp_it; - ++expected_it; - } - if (expected_it == expected.end()) { - it = tmp_it; - return true; - } - return false; -} - -static std::optional parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) { - std::smatch match; - if (std::regex_match(it, end, match, expected)) { - it = match.suffix().first; - return match; - } - return std::nullopt; -} - -static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) { - while (it != end && std::isspace(*it)) { - ++it; + arguments = (json {{"code", code}}).dump(); } + return arguments; } /** * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Aggregates the prefix, suffix and in-between text into the content. */ -static common_chat_msg parse_json_tool_calls( - const std::string& input, - const std::optional & trigger_opt, - const std::regex & function_regex, - const std::regex & close_regex, - bool allow_raw_python = false) { - std::smatch match; +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = nullptr) { - common_chat_msg result; - result.role = "assistant"; + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto res = function_regex_start_only && first + ? builder.try_consume_regex(*function_regex_start_only) + : function_regex + ? builder.try_find_regex(*function_regex, from) + : std::nullopt; + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. + from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; - - auto end = input.end(); - auto it = input.begin(); - - if (trigger_opt) { - if (!std::regex_search(it, end, match, *trigger_opt)) { - result.content = input; - return result; - } - result.content = match.prefix().str(); - it = match.suffix().first; - } - - while (it != end) { - std::sregex_iterator rend; - std::sregex_iterator rit(it, end, function_regex); - if (rit == rend) { - result.content += std::string(it, end); + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } break; } - auto name = rit->str(1); - result.content += std::string(it, rit->prefix().second); - it = rit->suffix().first; - - json arguments; - if (parse_json(it, end, arguments)) { - if (!std::regex_search(it, end, match, close_regex)) { - throw std::runtime_error("Malformed input, missing closing pattern: " + input); - } - it = match.suffix().first; - result.tool_calls.push_back({name, arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ ""}); - } else { - if (allow_raw_python && name == "python") { - result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""}); - break; - } - throw std::runtime_error("Failed to parse json tool call arguments: " + input); + if (block_close) { + builder.consume_regex(*block_close); } - } - - if (!result.tool_calls.empty()) { - if (!string_strip(result.content).empty()) { - LOG_WRN("Content found with tool calls: %s\n", result.content.c_str()); - } - result.content = ""; - } - return result; -} - -static common_chat_tool_call process_tool_call(const json & tool_call) { - const auto & arguments = tool_call.at("arguments"); - return { - /* .name = */ tool_call.at("name"), - /* .arguments = */ arguments.is_string() ? arguments.get() : arguments.dump(), - /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "", + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); }; -} -static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { - auto content_end = input.find(prefix); - size_t tc_start = std::string::npos; - - common_chat_msg result; - result.role = "assistant"; - if (content_end == std::string::npos) { - result.content = input; - } else { - tc_start = content_end + prefix.size() - rstrip_prefix; - result.content = input.substr(0, content_end); - auto tool_calls = json::parse(input.substr(tc_start)); - for (const auto & tool_call : tool_calls) { - result.tool_calls.emplace_back(process_tool_call(tool_call)); + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); } + } else { + parse_tool_calls(); + } +} + +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { + static const std::vector> args_paths = {{"arguments"}}; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); } - return result; } static void foreach_function(const json & tools, const std::function & fn) { @@ -742,29 +832,36 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } -static common_chat_msg common_chat_parse_generic(const std::string & input) { - json data = json::parse(input); - common_chat_msg result; - result.role = "assistant"; - if (data.contains("tool_calls")) { - for (const auto & tool_call : data.at("tool_calls")) { - result.tool_calls.push_back({ - tool_call.at("name"), - tool_call.at("arguments").dump(), - tool_call.contains("id") ? tool_call.at("id") : "", - }); - } - } else if (data.contains("tool_call")) { - result.tool_calls.push_back({ - data.at("tool_call").at("name"), - data.at("tool_call").at("arguments").dump(), - /* id= */ "", - }); - } else if (data.contains("response")) { - const auto & response = data.at("response"); - result.content = response.is_string() ? response.get() : response.dump(2); +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? response.template get() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); } - return result; } static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -811,12 +908,44 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; return data; } -static common_chat_msg common_chat_parse_mistral_nemo(const std::string & input) { - return parse_prefixed_json_tool_call_array(input, "[TOOL_CALLS]"); +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); } static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; + + auto adjusted_messages = json::array(); + for (const auto & msg : inputs.messages) { + auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + if (has_reasoning_content && has_tool_calls) { + auto adjusted_message = msg; + adjusted_message["tool_plan"] = msg.at("reasoning_content"); + adjusted_message.erase("reasoning_content"); + adjusted_messages.push_back(adjusted_message); + } else { + adjusted_messages.push_back(msg); + } + } + data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); + data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; + if (string_ends_with(data.prompt, "<|START_THINKING|>")) { + if (!inputs.enable_thinking) { + data.prompt += "<|END_THINKING|>"; + } else { + data.thinking_forced_open = true; + } + } else if (!inputs.enable_thinking && string_ends_with(data.prompt, "<|CHATBOT_TOKEN|>")) { + data.prompt += "<|START_THINKING|><|END_THINKING|>"; + } + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { auto schemas = json::array(); @@ -847,11 +976,16 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ if (!inputs.parallel_tool_calls) { schema["maxItems"] = 1; } - builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"<|END_THINKING|>\" space )? " : "") + + "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); }); data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "<|START_ACTION|>", + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?<\\|END_THINKING\\|>\\s*)?") + + "(<\\|START_ACTION\\|>)[\\s\\S]*" }); data.preserved_tokens = { "<|START_ACTION|>", @@ -861,61 +995,40 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ "<|START_THINKING|>", "<|END_THINKING|>", }; - auto adjusted_messages = json::array(); - for (const auto & msg : inputs.messages) { - auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); - auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); - if (has_reasoning_content && has_tool_calls) { - auto adjusted_message = msg; - adjusted_message["tool_plan"] = msg.at("reasoning_content"); - adjusted_message.erase("reasoning_content"); - adjusted_messages.push_back(adjusted_message); - } else { - adjusted_messages.push_back(msg); - } - } - data.prompt = apply(tmpl, adjusted_messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, {}); - data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING : COMMON_CHAT_FORMAT_COMMAND_R7B; return data; } -static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { - static const std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)"); - static const std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>"); - static const std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>"); - std::smatch match; +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - common_chat_msg result; - result.role = "assistant"; + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - std::string rest = input; - - if (std::regex_match(rest, match, thought_regex)) { - if (extract_reasoning) { - result.reasoning_content = match[2].str(); - } else if (!match[2].str().empty()) { - // Let the unparsed thinking tags through in content only if their insides aren't empty. - result.content = match[1].str(); + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } } - rest = match[3].str(); - } - if (std::regex_match(rest, match, action_regex)) { - auto actions_str = match[1].str(); - auto actions = json::parse(actions_str); - for (const auto & action : actions) { - result.tool_calls.push_back({ - /* .name = */ action.at("tool_name"), - /* .arguments = */ action.at("parameters").dump(), - /* .id = */ action.at("tool_call_id"), - }); + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); } - } else if (std::regex_match(rest, match, response_regex)) { - auto response = match[1].str(); - result.content += response; } else { - result.content += rest; + builder.add_content(builder.consume_rest()); } - return result; } static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { @@ -937,152 +1050,143 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } -static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { +static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { auto builtin_tools = json::array(); common_chat_params data; - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - - auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { - if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py - expect_tool_parameters(name, parameters, {"query"}); - } else if (name == "python" || name == "code_interpreter") { - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py - expect_tool_parameters(name, parameters, {"code"}); - } else { - return false; - } - - std::vector kvs; - for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT - } - - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); - builtin_tools.push_back(name); - - return true; - }; - - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime - if (allow_python_tag_builtin_tools) { - handle_builtin_tool(name, parameters); - } - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"{\" space " - "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? " - " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " - " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " " - "\"}\" space")); - }); - // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name. - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, - "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"", // + name + "\"[\\s\\S]*", - }); - if (!builtin_tools.empty()) { - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); - } - // Allow a few empty lines on top of the usual constrained json schema space rule. - builder.add_rule("root", string_join(tool_rules, " | ")); - }); - data.additional_stops.push_back("<|eom_id|>"); - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { - {"tools_in_user_message", false}, - {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, - }); - data.format = allow_python_tag_builtin_tools && !builtin_tools.empty() - ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS - : COMMON_CHAT_FORMAT_LLAMA_3_X; - return data; -} -static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { - // TODO: tighten & simplify the parser, don't accept leading text context. - static const std::regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const std::regex close_regex("\\}\\s*"); - static const std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); - - if (with_builtin_tools) { - std::smatch match; - if (std::regex_match(input, match, builtin_call_regex)) { - try { - auto name = match[1].str(); - auto arg_name = match[2].str(); - auto arg_value_str = match[3].str(); - auto arg_value = json::parse(arg_value_str); - - common_chat_msg msg; - msg.role = "assistant"; - msg.tool_calls.push_back({ - /* .name = */ name, - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", - }); - return msg; - } catch (const std::exception & e) { - LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str()); - } - } - } - return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); -} - -static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); + if (!inputs.tools.is_null()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; + + auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { + if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py + expect_tool_parameters(name, parameters, {"query"}); + } else if (name == "python" || name == "code_interpreter") { + // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py + expect_tool_parameters(name, parameters, {"code"}); + } else { + return false; + } + + std::vector kvs; + for (const auto & [key, value] : parameters.at("properties").items()) { + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT + } + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); + builtin_tools.push_back(name); + + return true; + }; + foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" - "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " " - "\"```<|tool▁call▁end|>\"")); + + // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime + if (allow_python_tag_builtin_tools) { + handle_builtin_tool(name, parameters); + } + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"{\" space " + "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? " + " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " + " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " " + "\"}\" space")); }); - // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, - // so we accept common variants (then it's all constrained) - builder.add_rule("root", - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) " - "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " - "\"<|tool▁calls▁end|>\"" - " space"); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"}); - data.preserved_tokens = { - "", - "", - "<|tool▁calls▁begin|>", - "<|tool▁call▁begin|>", - "<|tool▁sep|>", - "<|tool▁call▁end|>", - "<|tool▁calls▁end|", - }; + // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name. + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + "(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*", + }); + if (!builtin_tools.empty()) { + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); + data.preserved_tokens.push_back("<|python_tag|>"); + } + // Allow a few empty lines on top of the usual constrained json schema space rule. + builder.add_rule("root", string_join(tool_rules, " | ")); + data.additional_stops.push_back("<|eom_id|>"); }); + data.format = allow_python_tag_builtin_tools && !builtin_tools.empty() + ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS + : COMMON_CHAT_FORMAT_LLAMA_3_X; + } else { + data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; } + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, { + {"date_string", format_time(inputs.now, "%d %b %Y")}, + {"tools_in_user_message", false}, + {"builtin_tools", builtin_tools.empty() ? json() : builtin_tools}, + }); + return data; +} +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + + if (with_builtin_tools) { + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); + + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } + } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; + } + } + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + +} + +static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; auto prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); // Hacks to fix the official (broken) prompt. @@ -1103,52 +1207,83 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); } data.prompt = prompt; - data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING : COMMON_CHAT_FORMAT_DEEPSEEK_R1; + data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; + if (string_ends_with(data.prompt, "\n")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (inputs.tools.is_array() && !inputs.tools.empty()) { + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_rule(name + "-call", + "( \"<|tool▁call▁begin|>\" )? \"function<|tool▁sep|>" + name + "\\n" + "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " " + "\"```<|tool▁call▁end|>\"")); + }); + // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, + // so we accept common variants (then it's all constrained) + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " + "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " + "\"<|tool▁calls▁end|>\"" + " space"); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + + "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" + }); + data.preserved_tokens = { + "", + "", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>", + "<|tool▁sep|>", + "<|tool▁call▁end|>", + "<|tool▁calls▁end|", + }; + }); + } return data; } -static common_chat_msg handle_think_tag_prelude(const std::string & input, bool extract_reasoning, const std::function & rest_parser) { - std::smatch match; - static const std::regex reasoning_content_regex("((?:)?([\\s\\S\\r\\n]*?))?([\\s\\S\\r\\n]*)"); - if (std::regex_match(input, match, reasoning_content_regex)) { - auto rest = match[3].str(); - auto msg = rest_parser(rest); - auto reasoning_content = string_strip(match[2].str()); - if (extract_reasoning) { - msg.reasoning_content = reasoning_content; - } else if (!reasoning_content.empty()) { - std::ostringstream content; - content << "" << reasoning_content << "" << msg.content; - msg.content = content.str(); - } - return msg; +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; } - return rest_parser(input); -} -static common_chat_msg common_chat_parse_deepseek_r1(const std::string & input, bool extract_reasoning) { - return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) { - static const std::regex function_regex("<|tool▁call▁begin|>function<|tool▁sep|>([^\n]+)\n```json\n"); - static const std::regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - static const std::regex tool_calls_regex("[\\s\\r\\n]*(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>)([\\s\\S\\r\\n]*?)<|tool▁calls▁end|>"); - common_chat_msg msg; - msg.role = "assistant"; - std::smatch match; - if (std::regex_search(input, match, tool_calls_regex)) { - auto tool_calls = match[1].str(); - auto msg2 = parse_json_tool_calls(tool_calls, std::nullopt, function_regex, close_regex); - msg.tool_calls = std::move(msg2.tool_calls); - } else { - msg.content = input; - } - return msg; - }); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); } static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; data.prompt = apply(tmpl, inputs.messages, /* tools= */ nullptr, inputs.add_generation_prompt, { - {"datetime", "Jan 29 2025 13:00:00 GMT"}, + {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, }); if (inputs.tools.is_array() && !inputs.tools.empty()) { @@ -1189,13 +1324,19 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } return data; } -static common_chat_msg common_chat_parse_firefunction_v2(const std::string & input) { - return parse_prefixed_json_tool_call_array(input, " functools[", /* rstrip_prefix= */ 1); +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); } static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar + // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. common_chat_params data; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; @@ -1209,24 +1350,21 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); + std::string args_pattern = "[\\s\\S]*"; auto args_rule = builder.add_schema(name + "-args", parameters); - first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule)); - subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); + if (name == "python") { + args_rule = builder.add_rule(name + "-maybe-raw-args", args_rule + " | [^{] .*"); + } else { + args_pattern = "\\{" + args_pattern; + } + auto call_rule = builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule); + first_tool_rules.push_back(call_rule); + if (inputs.parallel_tool_calls) { + subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>\" " + call_rule)); + } data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, - regex_escape(name + "\n"), - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, - regex_escape("assistant<|end_header_id|>\n" + name + "\n"), - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - regex_escape(">>>" + name + "\n"), - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - ">>>assistant<|end_header_id|>\n" + name, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + "((?:[\\s\\S]+?>>>)?" + regex_escape(name) + "\n)" + args_pattern, }); }); data.preserved_tokens = { @@ -1244,319 +1382,311 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } return data; } +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); -static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { - static const std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)"); - static const std::regex close_regex(R"($|(?=>>>))"); - - std::string content; - auto it = input.begin(); - const auto end = input.end(); - - if (parse_literal(it, end, "all\n")) { - std::smatch match; - if (std::regex_search(it, end, match, function_regex)) { - auto fun_it = match.prefix().second; - content = std::string(it, fun_it); - it = fun_it; - } else { - common_chat_msg res; - res.role = "assistant"; - res.content = std::string(it, end); - return res; - } - } - // TODO: tighten & simplify. - try { - auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true); - res.content = content + res.content; - return res; - } catch (const std::exception & e) { - LOG_ERR("Failed to parse functionary v3.2 input: %s\n", e.what()); - common_chat_msg res; - res.role = "assistant"; - res.content = input; - return res; - } + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); } static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt common_chat_params data; - json tools = inputs.tools.is_null() ? inputs.tools : json::array(); - std::string python_code_argument_name; - auto has_raw_python = false; - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - const auto & parameters = function.at("parameters"); - std::string name = function.at("name"); - if (name == "python" || name == "ipython") { - if (!parameters.contains("type")) { - throw std::runtime_error("Missing type in python tool"); - } - has_raw_python = true; - const auto & type = parameters.at("type"); - if (type == "object") { - auto properties = parameters.at("properties"); - for (auto it = properties.begin(); it != properties.end(); ++it) { - if (it.value().at("type") == "string") { - if (!python_code_argument_name.empty()) { - throw std::runtime_error("Multiple string arguments found in python tool"); + if (!inputs.tools.is_null()) { + std::string python_code_argument_name; + auto has_raw_python = false; + + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + const auto & parameters = function.at("parameters"); + std::string name = function.at("name"); + if (name == "python" || name == "ipython") { + if (!parameters.contains("type")) { + throw std::runtime_error("Missing type in python tool"); + } + has_raw_python = true; + const auto & type = parameters.at("type"); + if (type == "object") { + auto properties = parameters.at("properties"); + for (auto it = properties.begin(); it != properties.end(); ++it) { + if (it.value().at("type") == "string") { + if (!python_code_argument_name.empty()) { + throw std::runtime_error("Multiple string arguments found in python tool"); + } + python_code_argument_name = it.key(); } - python_code_argument_name = it.key(); } + if (python_code_argument_name.empty()) { + throw std::runtime_error("No string argument found in python tool"); + } + } else if (type != "string") { + throw std::runtime_error("Invalid type in python tool: " + type.dump()); } - if (python_code_argument_name.empty()) { - throw std::runtime_error("No string argument found in python tool"); - } - } else if (type != "string") { - throw std::runtime_error("Invalid type in python tool: " + type.dump()); } + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); + }); + if (has_raw_python) { + tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); + data.preserved_tokens.push_back("<|python_tag|>"); } - tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); + auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; + builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "\" .*")); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); - } - auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; - builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "([\s\S\n]*)$)"); - std::smatch match; - if (std::regex_search(input, match, python_tag_regex)) { - auto code = match[1].str(); - common_chat_msg msg; - msg.role = "assistant"; - msg.content = match.prefix().str(); - msg.tool_calls.push_back({ - /* .name = */ "python", - /* .arguments = */ (json {{"code", code}}).dump(), - /* .id = */ "", - }); - return msg; +static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; } - static const std::regex function_regex(R"()"); - static const std::regex close_regex(R"()"); - // TODO: tighten & simplify. - return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); } static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; - // (content)?({"name": "foo", "arguments": {"a": 1}})* - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - std::vector tool_call_alts; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_schema(name + "-call", { - {"type", "object"}, - {"properties", json { - {"name", json {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - })); - tool_call_alts.push_back(builder.add_rule( - name + "-function-tag", - "\"\" space " + - builder.add_schema(name + "-args", parameters) + " " - "\"\" space")); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "", - }); - auto escaped_name = regex_escape(name); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - " alt_tags { - any_tool_call, - "\"\" space " + any_tool_call + " \"\"", - // The rest is just to accommodate common "good bad" outputs. - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - }; - auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); - tool_call_alts.push_back(wrappable_tool_call); - tool_call_alts.push_back( - "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); - auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); - builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, ""}); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "|||)?\\s*\\{\\s*\"", //name\"\\s*:\\s*\"" + escaped_name + "\"", - }); - data.preserved_tokens = { - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "```", - "```json", - "```xml", - }; - }); + json additional_context = { + {"enable_thinking", inputs.enable_thinking}, + }; + + data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt, additional_context); + data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; + if (string_ends_with(data.prompt, "\n")) { + if (!inputs.enable_thinking) { + data.prompt += ""; + } else { + data.thinking_forced_open = true; + } + } + + if (!inputs.tools.is_null()) { + // (content)?({"name": "foo", "arguments": {"a": 1}})* + data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + std::vector tool_rules; + std::vector tool_call_alts; + std::vector escaped_names; + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + tool_call_alts.push_back(builder.add_rule( + name + "-function-tag", + "\"\" space " + + builder.add_schema(name + "-args", parameters) + " " + "\"\" space")); + + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + "", + }); + auto escaped_name = regex_escape(name); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + " alt_tags { + any_tool_call, + "\"\" space " + any_tool_call + " \"\"", + // The rest is just to accommodate common "good bad" outputs. + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + }; + auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); + tool_call_alts.push_back(wrappable_tool_call); + tool_call_alts.push_back( + "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); + auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); + builder.add_rule("root", + std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + + (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); + // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + // If thinking_forced_open, then we capture the tag in the grammar, + // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) + std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( + "(\\s*" + "(?:" + "||||)?" + "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" + ")" + ")[\\s\\S]*" + ), + }); + data.preserved_tokens = { + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "```", + "```json", + "```xml", + }; + }); + } - data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); - data.format = inputs.extract_reasoning ? COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING : COMMON_CHAT_FORMAT_HERMES_2_PRO; return data; } -static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input, bool extract_reasoning) { - return handle_think_tag_prelude(input, extract_reasoning, [](const std::string & input) { - static const std::regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "|" - "|" - "|" - "|" - "|" - "|" - "|" +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" ")?" - "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest) - ")" - "|" - "(?:]+)>" // match 4 (function name) - "|)" // match 5 (function name again) - "([\\s\\S]*)" // match 6 (function arguments + rest)})" - ); + "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); - try { - common_chat_msg msg; - msg.role = "assistant"; + if (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; - std::string::const_iterator it = input.begin(); - const std::string::const_iterator end = input.end(); - std::smatch match; + const auto & open_tag = res->groups[2]; + std::string close_tag; - while (it != end) { - if (std::regex_search(it, end, match, open_regex)) { - // Add content before the match - msg.content += std::string(it, match[0].first); + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + builder.add_content(builder.consume_rest()); + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); - auto open_tag = match[2].str(); - std::string close_tag; + close_tag = ""; - if (match[3].matched) { - close_tag = open_tag.empty() ? "" : ""; - // Start parsing from after the opening tags - auto json_it = match[6].first; - json arguments; - if (parse_json(json_it, end, arguments)) { - msg.tool_calls.emplace_back(process_tool_call({ - {"name", function_name}, - {"arguments", arguments}, - })); - it = json_it; // Move iterator past parsed JSON - - // Handle close tags - consume_spaces(it, end); - if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { - throw std::runtime_error("Failed to parse closing tag"); - } - consume_spaces(it, end); - if (!block_end.empty() && !parse_literal(it, end, block_end)) { - throw std::runtime_error("Failed to parse block end"); - } - consume_spaces(it, end); - } else { - // Not a valid tool call, treat as content - msg.content += std::string(match[0].first, match[0].second); - it = match[0].second; - } - } - } else { - // Add remaining content - msg.content += std::string(it, end); - break; + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); } } - return msg; - } catch (const std::exception & e) { - LOG_ERR("Failed to parse hermes 2 pro input: %s\n", e.what()); - common_chat_msg msg; - msg.role = "assistant"; - msg.content = input; - return msg; + builder.add_content(builder.consume_rest()); } - }); + } else { + builder.add_content(builder.consume_rest()); + } } static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { @@ -1588,9 +1718,10 @@ static common_chat_params common_chat_templates_apply_jinja( const auto & caps = tmpl.original_caps(); params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); params.add_generation_prompt = inputs.add_generation_prompt; - params.extract_reasoning = inputs.extract_reasoning; params.tool_choice = inputs.tool_choice; + params.enable_thinking = inputs.enable_thinking; params.grammar = inputs.grammar; + params.now = inputs.now; if (!inputs.json_schema.empty()) { params.json_schema = json::parse(inputs.json_schema); } @@ -1622,7 +1753,7 @@ static common_chat_params common_chat_templates_apply_jinja( } // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("") != std::string::npos && params.json_schema.is_null() && params.tools.is_array() && params.json_schema.is_null()) { + if (src.find("") != std::string::npos && params.json_schema.is_null()) { return common_chat_params_init_hermes_2_pro(tmpl, params); } @@ -1642,21 +1773,21 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_firefunction_v2(tmpl, params); } - // Plain handler (no tools) - if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return common_chat_params_init_without_tools(tmpl, params); - } - // Functionary v3.1 (w/ tools) if (src.find("<|start_header_id|>") != std::string::npos && src.find("ipython<|end_header_id|>") != std::string::npos) { auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; - return common_chat_params_init_llama_3_1_tool_calls(tmpl, params, allow_python_tag_builtin_tools); + return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); + } + + // Plain handler (no tools) + if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return common_chat_params_init_without_tools(tmpl, params); } // Mistral Nemo (w/ tools) @@ -1707,7 +1838,7 @@ static common_chat_params common_chat_templates_apply_legacy( if (res < 0) { // if the custom "tmpl" is not supported, we throw an error // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); + throw std::runtime_error("this custom template is not supported, try using --jinja"); } // if it turns out that our buffer is too small, we resize it @@ -1736,44 +1867,66 @@ common_chat_params common_chat_templates_apply( : common_chat_templates_apply_legacy(tmpls, inputs); } -static common_chat_msg common_chat_parse_content_only(const std::string & input) { - common_chat_msg msg; - msg.role = "assistant"; - msg.content = input; - return msg; +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.add_content(builder.consume_rest()); } -common_chat_msg common_chat_parse(const std::string & input, common_chat_format format) { - switch (format) { +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { case COMMON_CHAT_FORMAT_CONTENT_ONLY: - return common_chat_parse_content_only(input); + common_chat_parse_content_only(builder); + break; case COMMON_CHAT_FORMAT_GENERIC: - return common_chat_parse_generic(input); + common_chat_parse_generic(builder); + break; case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - return common_chat_parse_mistral_nemo(input); + common_chat_parse_mistral_nemo(builder); + break; case COMMON_CHAT_FORMAT_LLAMA_3_X: - return common_chat_parse_llama_3_1(input); + common_chat_parse_llama_3_1(builder); + break; case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - return common_chat_parse_llama_3_1(input, /* with_builtin_tools= */ true); + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ false); - case COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING: - return common_chat_parse_deepseek_r1(input, /* extract_reasoning= */ true); + common_chat_parse_deepseek_r1(builder); + break; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - return common_chat_parse_functionary_v3_2(input); + common_chat_parse_functionary_v3_2(builder); + break; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - return common_chat_parse_functionary_v3_1_llama_3_1(input); + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; case COMMON_CHAT_FORMAT_HERMES_2_PRO: - return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ false); - case COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING: - return common_chat_parse_hermes_2_pro(input, /* extract_reasoning= */ true); + common_chat_parse_hermes_2_pro(builder); + break; case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - return common_chat_parse_firefunction_v2(input); + common_chat_parse_firefunction_v2(builder); + break; case COMMON_CHAT_FORMAT_COMMAND_R7B: - return common_chat_parse_command_r7b(input, /* extract_reasoning= */ false); - case COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING: - return common_chat_parse_command_r7b(input, /* extract_reasoning= */ true); + common_chat_parse_command_r7b(builder); + break; default: - throw std::runtime_error("Unsupported format: " + common_chat_format_name(format)); + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + builder.clear_tools(); + builder.move_to(0); + common_chat_parse_content_only(builder); + } + } + auto msg = builder.result(); + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + return msg; } diff --git a/common/chat.h b/common/chat.h index 9aad84e88..9f59e6b08 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,6 +3,8 @@ #pragma once #include "common.h" +#include +#include #include #include @@ -12,11 +14,19 @@ struct common_chat_tool_call { std::string name; std::string arguments; std::string id; + + bool operator==(const common_chat_tool_call & other) const { + return name == other.name && arguments == other.arguments && id == other.id; + } }; struct common_chat_msg_content_part { std::string type; std::string text; + + bool operator==(const common_chat_msg_content_part & other) const { + return type == other.type && text == other.text; + } }; struct common_chat_msg { @@ -27,6 +37,51 @@ struct common_chat_msg { std::string reasoning_content; std::string tool_name; std::string tool_call_id; + + template T to_json_oaicompat() const; + + bool empty() const { + return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); + } + void ensure_tool_call_ids_set(std::vector & ids_cache, const std::function & gen_tool_call_id) { + for (auto i = 0u; i < tool_calls.size(); i++) { + if (ids_cache.size() <= i) { + auto id = tool_calls[i].id; + if (id.empty()) { + id = gen_tool_call_id(); + } + ids_cache.push_back(id); + } + tool_calls[i].id = ids_cache[i]; + } + } + bool operator==(const common_chat_msg & other) const { + return role == other.role + && content == other.content + && content_parts == other.content_parts + && tool_calls == other.tool_calls + && reasoning_content == other.reasoning_content + && tool_name == other.tool_name + && tool_call_id == other.tool_call_id; + } + bool operator!=(const common_chat_msg & other) const { + return !(*this == other); + } +}; + +struct common_chat_msg_diff { + std::string reasoning_content_delta; + std::string content_delta; + size_t tool_call_index = std::string::npos; + common_chat_tool_call tool_call_delta; + + static std::vector compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg); + + bool operator==(const common_chat_msg_diff & other) const { + return content_delta == other.content_delta + && tool_call_index == other.tool_call_index + && tool_call_delta == other.tool_call_delta; + } }; struct common_chat_tool { @@ -48,14 +103,11 @@ enum common_chat_format { COMMON_CHAT_FORMAT_LLAMA_3_X, COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, COMMON_CHAT_FORMAT_DEEPSEEK_R1, - COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, COMMON_CHAT_FORMAT_FIREFUNCTION_V2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_HERMES_2_PRO, - COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING, COMMON_CHAT_FORMAT_COMMAND_R7B, - COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; @@ -70,7 +122,9 @@ struct common_chat_templates_inputs { std::vector tools; common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; bool parallel_tool_calls = false; - bool extract_reasoning = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); }; struct common_chat_params { @@ -78,11 +132,21 @@ struct common_chat_params { std::string prompt; std::string grammar; bool grammar_lazy = false; + bool thinking_forced_open = false; std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; }; +struct common_chat_syntax { + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; + // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) + bool reasoning_in_content = false; + bool thinking_forced_open = false; + bool parse_tool_calls = true; +}; + // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); @@ -119,8 +183,9 @@ std::string common_chat_format_example( const struct common_chat_templates * tmpls, bool use_jinja); -std::string common_chat_format_name(common_chat_format format); -common_chat_msg common_chat_parse( const std::string & input, common_chat_format format); +const char* common_chat_format_name(common_chat_format format); +const char* common_reasoning_format_name(common_reasoning_format format); +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); @@ -133,3 +198,5 @@ template T common_chat_msgs_to_json_oaicompat(const std::vector std::vector common_chat_tools_parse_oaicompat(const T & tools); template T common_chat_tools_to_json_oaicompat(const std::vector & tools); + +template T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); diff --git a/common/cmake/build-info-gen-cpp.cmake b/common/cmake/build-info-gen-cpp.cmake deleted file mode 100644 index fbc92b52c..000000000 --- a/common/cmake/build-info-gen-cpp.cmake +++ /dev/null @@ -1,24 +0,0 @@ -include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) - -set(TEMPLATE_FILE "${CMAKE_CURRENT_SOURCE_DIR}/common/build-info.cpp.in") -set(OUTPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/common/build-info.cpp") - -# Only write the build info if it changed -if(EXISTS ${OUTPUT_FILE}) - file(READ ${OUTPUT_FILE} CONTENTS) - string(REGEX MATCH "LLAMA_COMMIT = \"([^\"]*)\";" _ ${CONTENTS}) - set(OLD_COMMIT ${CMAKE_MATCH_1}) - string(REGEX MATCH "LLAMA_COMPILER = \"([^\"]*)\";" _ ${CONTENTS}) - set(OLD_COMPILER ${CMAKE_MATCH_1}) - string(REGEX MATCH "LLAMA_BUILD_TARGET = \"([^\"]*)\";" _ ${CONTENTS}) - set(OLD_TARGET ${CMAKE_MATCH_1}) - if ( - NOT OLD_COMMIT STREQUAL BUILD_COMMIT OR - NOT OLD_COMPILER STREQUAL BUILD_COMPILER OR - NOT OLD_TARGET STREQUAL BUILD_TARGET - ) - configure_file(${TEMPLATE_FILE} ${OUTPUT_FILE}) - endif() -else() - configure_file(${TEMPLATE_FILE} ${OUTPUT_FILE}) -endif() diff --git a/common/common.cpp b/common/common.cpp index 94f545f81..eb80cee08 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -203,6 +203,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { DWORD p = NORMAL_PRIORITY_CLASS; switch (prio) { + case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break; case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break; case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break; case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break; @@ -228,6 +229,7 @@ bool set_process_priority(enum ggml_sched_priority prio) { int p = 0; switch (prio) { + case GGML_SCHED_PRIO_LOW: p = 5; break; case GGML_SCHED_PRIO_NORMAL: p = 0; break; case GGML_SCHED_PRIO_MEDIUM: p = -5; break; case GGML_SCHED_PRIO_HIGH: p = -10; break; @@ -443,9 +445,28 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +bool string_ends_with(const std::string_view & str, const std::string_view & suffix) { + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; +} +size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) { + if (!str.empty() && !stop.empty()) { + const char text_last_char = str.back(); + for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { + if (stop[char_index] == text_last_char) { + const auto current_partial = stop.substr(0, char_index + 1); + if (string_ends_with(str, current_partial)) { + return str.size() - char_index - 1; + } + } + } + } + + return std::string::npos; +} + std::string regex_escape(const std::string & s) { static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); - return std::regex_replace(s, special_chars, "\\$0"); + return std::regex_replace(s, special_chars, "\\$&"); } std::string string_join(const std::vector & values, const std::string & separator) { @@ -746,6 +767,9 @@ bool fs_validate_filename(const std::string & filename) { return true; } +#include + + // returns true if successful, false otherwise bool fs_create_directory_with_parents(const std::string & path) { #ifdef _WIN32 @@ -763,9 +787,16 @@ bool fs_create_directory_with_parents(const std::string & path) { // process path from front to back, procedurally creating directories while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { const std::wstring subpath = wpath.substr(0, pos_slash); - const wchar_t * test = subpath.c_str(); - const bool success = CreateDirectoryW(test, NULL); + pos_slash += 1; + + // skip the drive letter, in some systems it can return an access denied error + if (subpath.length() == 2 && subpath[1] == ':') { + continue; + } + + const bool success = CreateDirectoryW(subpath.c_str(), NULL); + if (!success) { const DWORD error = GetLastError(); @@ -779,8 +810,6 @@ bool fs_create_directory_with_parents(const std::string & path) { return false; } } - - pos_slash += 1; } return true; @@ -830,7 +859,7 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else { @@ -876,31 +905,6 @@ struct common_init_result common_init_from_params(common_params & params) { const llama_vocab * vocab = llama_model_get_vocab(model); - if (params.reranking) { - bool ok = true; - - if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); - ok = false; - } - - if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); - ok = false; - } - - if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); - ok = false; - } - - if (!ok) { - llama_model_free(model); - - return iparams; - } - } - auto cparams = common_context_params_to_llama(params); llama_context * lctx = llama_init_from_model(model, cparams); @@ -910,7 +914,7 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } - if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) { + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); params.ctx_shift = false; } @@ -942,6 +946,35 @@ struct common_init_result common_init_from_params(common_params & params) { } } + if (llama_pooling_type(lctx) == LLAMA_POOLING_TYPE_RANK) { + bool ok = true; + + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); + ok = false; + } + + bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; + bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL; + + if (!has_eos && !has_sep) { + LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__); + ok = false; + } else if (!has_eos) { + LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__); + } else if (!has_sep) { + LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); + ok = false; + } + + if (!ok) { + llama_free(lctx); + llama_model_free(model); + + return iparams; + } + } + // load and optionally apply lora adapters for (auto & la : params.lora_adapters) { llama_adapter_lora_ptr lora; @@ -1017,7 +1050,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_decoder(model)) { llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); } - llama_kv_self_clear(lctx); + llama_memory_clear(llama_get_memory(lctx), true); llama_synchronize(lctx); llama_perf_context_reset(lctx); llama_set_warmup(lctx, false); @@ -1083,6 +1116,9 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.tensor_buft_overrides = params.tensor_buft_overrides.data(); } + mparams.progress_callback = params.load_progress_callback; + mparams.progress_callback_user_data = params.load_progress_callback_user_data; + return mparams; } @@ -1096,7 +1132,6 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? params.cpuparams.n_threads : params.cpuparams_batch.n_threads; - cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_freq_base = params.rope_freq_base; @@ -1114,11 +1149,8 @@ struct llama_context_params common_context_params_to_llama(const common_params & cparams.offload_kqv = !params.no_kv_offload; cparams.flash_attn = params.flash_attn; cparams.no_perf = params.no_perf; - - if (params.reranking) { - cparams.embeddings = true; - cparams.pooling_type = LLAMA_POOLING_TYPE_RANK; - } + cparams.op_offload = !params.no_op_offload; + cparams.swa_full = params.swa_full; cparams.type_k = params.cache_type_k; cparams.type_v = params.cache_type_v; @@ -1306,81 +1338,6 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto return text; } -// -// KV cache utils -// - -void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size) { - static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+"; - - printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d", - view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); - - llama_kv_cache_view_cell * c_curr = view.cells; - llama_seq_id * cs_curr = view.cells_sequences; - - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) { - if (i % row_size == 0) { - printf("\n%5d: ", i); - } - int seq_count = 0; - for (int j = 0; j < view.n_seq_max; j++) { - if (cs_curr[j] >= 0) { seq_count++; } - } - putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]); - } - - printf("\n=== Done dumping\n"); -} - -void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size) { - static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - - printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", - view.n_cells, view.n_seq_max, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); - - std::unordered_map seqs; - llama_kv_cache_view_cell * c_curr = view.cells; - llama_seq_id * cs_curr = view.cells_sequences; - - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) { - for (int j = 0; j < view.n_seq_max; j++) { - if (cs_curr[j] < 0) { continue; } - if (seqs.find(cs_curr[j]) == seqs.end()) { - if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } - const size_t sz = seqs.size(); - seqs[cs_curr[j]] = sz; - } - } - if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } - } - - printf("=== Sequence legend: "); - for (const auto & it : seqs) { - printf("%zu=%d, ", it.second, it.first); - } - printf("'+'=other sequence ids"); - - c_curr = view.cells; - cs_curr = view.cells_sequences; - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_seq_max) { - if (i % row_size == 0) { - printf("\n%5d: ", i); - } - for (int j = 0; j < view.n_seq_max; j++) { - if (cs_curr[j] >= 0) { - const auto & it = seqs.find(cs_curr[j]); - putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+'); - } else { - putchar('.'); - } - } - putchar(' '); - } - - printf("\n=== Done dumping\n"); -} - // // Embedding utils // @@ -1565,3 +1522,20 @@ common_control_vector_data common_control_vector_load(const std::vector & tokens, int64_t stride) { + const int64_t ne_datapoint = llama_n_ctx(ctx); + const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride; + ggml_opt_dataset_t result = ggml_opt_dataset_init( + GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1); + + llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data; + llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data; + + for (int64_t idata = 0; idata < ndata; ++idata) { + memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token)); + memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token)); + } + + return result; +} diff --git a/common/common.h b/common/common.h index e6eaa8e80..00b6ca03a 100644 --- a/common/common.h +++ b/common/common.h @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -66,7 +67,6 @@ enum llama_example { LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_MAIN, - LLAMA_EXAMPLE_INFILL, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL, @@ -76,7 +76,7 @@ enum llama_example { LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, - LLAMA_EXAMPLE_LLAVA, + LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_PARALLEL, LLAMA_EXAMPLE_TTS, @@ -96,6 +96,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_PENALTIES = 10, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11, }; // dimensionality reduction methods, used by cvector-generator @@ -114,7 +115,7 @@ enum common_grammar_trigger_type { COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, COMMON_GRAMMAR_TRIGGER_TYPE_WORD, COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, }; struct common_grammar_trigger { @@ -161,6 +162,7 @@ struct common_params_sampling { std::vector samplers = { COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_DRY, + COMMON_SAMPLER_TYPE_TOP_N_SIGMA, COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TYPICAL_P, COMMON_SAMPLER_TYPE_TOP_P, @@ -213,7 +215,8 @@ struct common_params_vocoder { enum common_reasoning_format { COMMON_REASONING_FORMAT_NONE, - COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content` + COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in tags in stream mode + COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas. }; struct common_params { @@ -289,6 +292,7 @@ struct common_params { int32_t verbosity = 0; int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector + bool offline = false; int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line @@ -321,17 +325,17 @@ struct common_params { bool flash_attn = false; // flash attention bool no_perf = false; // disable performance metrics bool ctx_shift = true; // context shift on inifinite text generation + bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory bool verbose_prompt = false; // print prompt tokens before generation bool display_prompt = true; // print prompt before generation - bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data + bool no_op_offload = false; // globally disable offload host tensor operations to device bool single_turn = false; // single turn chat conversation @@ -340,8 +344,10 @@ struct common_params { common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; - // multimodal models (see examples/llava) + // multimodal models (see tools/mtmd) struct common_params_model mmproj; + bool mmproj_use_gpu = true; // use GPU for multimodal model + bool no_mmproj = false; // explicitly disable multimodal model std::vector image; // path to image file(s) // embedding @@ -349,7 +355,6 @@ struct common_params { int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix std::string embd_sep = "\n"; // separator of embeddings - bool reranking = false; // enable reranking support on server // server params int32_t port = 8080; // server listens on this network port @@ -364,6 +369,8 @@ struct common_params { bool use_jinja = false; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + int reasoning_budget = -1; + bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response std::vector api_keys; @@ -407,13 +414,14 @@ struct common_params { bool process_output = false; // collect data for the output tensor bool compute_ppl = true; // whether to compute perplexity + bool parse_special = false; // whether to parse special tokens during imatrix tokenization // cvector-generator params int n_pca_batch = 100; int n_pca_iterations = 1000; dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; - std::string cvector_positive_file = "examples/cvector-generator/positive.txt"; - std::string cvector_negative_file = "examples/cvector-generator/negative.txt"; + std::string cvector_positive_file = "tools/cvector-generator/positive.txt"; + std::string cvector_negative_file = "tools/cvector-generator/negative.txt"; bool spm_infill = false; // suffix/prefix/middle pattern for infill @@ -422,6 +430,11 @@ struct common_params { // common params std::string out_file; // output filename for all example programs + // optional callback for model loading progress and cancellation: + // called with a progress value between 0.0 and 1.0. + // return false from callback to abort model loading or true to continue + llama_progress_callback load_progress_callback = NULL; + void * load_progress_callback_user_data = NULL; }; // call once at the start of a program if it uses libcommon @@ -499,10 +512,9 @@ static bool string_starts_with(const std::string & str, return str.rfind(prefix, 0) == 0; } -static bool string_ends_with(const std::string & str, - const std::string & suffix) { // While we wait for C++20's std::string::ends_with... - return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; -} +// While we wait for C++20's std::string::ends_with... +bool string_ends_with(const std::string_view & str, const std::string_view & suffix); +size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop); bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -611,16 +623,6 @@ std::string common_detokenize( const std::vector & tokens, bool special = true); -// -// KV cache utils -// - -// Dump the KV cache view with the number of sequences per cell. -void common_kv_cache_dump_view(const llama_kv_cache_view & view, int row_size = 80); - -// Dump the KV cache view showing individual sequences in each cell (long output). -void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_size = 40); - // // Embedding utils // @@ -662,3 +664,9 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; } + +// +// training utils +// + +ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector & tokens, int64_t stride); diff --git a/common/json-partial.cpp b/common/json-partial.cpp new file mode 100644 index 000000000..d9d916998 --- /dev/null +++ b/common/json-partial.cpp @@ -0,0 +1,256 @@ +#include "json-partial.h" + +#include "log.h" + +#include + +#include + +using json = nlohmann::ordered_json; + +enum common_json_stack_element_type { + COMMON_JSON_STACK_ELEMENT_OBJECT, + COMMON_JSON_STACK_ELEMENT_KEY, + COMMON_JSON_STACK_ELEMENT_ARRAY, +}; + +struct common_json_stack_element { + common_json_stack_element_type type; + std::string key; +}; + +bool common_json_parse( + const std::string & input, + const std::string & healing_marker, + common_json & out) +{ + std::string::const_iterator it = input.begin(); + const auto end = input.end(); + return common_json_parse(it, end, healing_marker, out); +} + +bool common_json_parse( + std::string::const_iterator & it, + const std::string::const_iterator & end, + const std::string & healing_marker, + common_json & out) +{ + // // https://json.nlohmann.me/features/parsing/sax_interface/ + struct json_error_locator : public nlohmann::json_sax { + std::size_t position; + bool found_error; + std::string last_token; + std::string exception_message; + std::vector stack; + + json_error_locator() : position(0), found_error(false) {} + + bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT + this->position = position - 1; + this->found_error = true; + this->last_token = last_token; + this->exception_message = ex.what(); + return false; + } + void close_value() { + if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) { + stack.pop_back(); + } + } + bool null() override { // NOLINT + close_value(); + return true; + } + bool boolean(bool) override { // NOLINT + close_value(); + return true; + } + bool number_integer(number_integer_t) override { // NOLINT + close_value(); + return true; + } + bool number_unsigned(number_unsigned_t) override { // NOLINT + close_value(); + return true; + } + bool number_float(number_float_t, const string_t &) override { // NOLINT + close_value(); + return true; + } + bool string(string_t &) override { // NOLINT + close_value(); + return true; + } + bool binary(binary_t &) override { // NOLINT + close_value(); + return true; + } + bool start_object(std::size_t) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""}); + return true; + } + bool end_object() override { + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT); + stack.pop_back(); + close_value(); + return true; + } + bool key(string_t & key) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key}); + return true; + } + bool start_array(std::size_t) override { // NOLINT + stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""}); + return true; + } + bool end_array() override { + GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY); + stack.pop_back(); + close_value(); + return true; + } + }; + json_error_locator err_loc; + auto start = it; + json::sax_parse(it, end, &err_loc); + + if (err_loc.found_error) { + it = start; + auto temptative_end = it + err_loc.position; + // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str()); + + auto input = std::string(it, temptative_end); + try { + out.json = json::parse(input); + // out.json = json::parse(it, temptative_end); + it = temptative_end; + return true; + } catch (const std::exception & ex) { + // No, needs healing. + LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str()); + } + auto can_parse = [](const std::string & str) { + try { + auto _ = json::parse(str); // NOLINT + return true; + } catch (const std::exception &) { + return false; + } + }; + if (!healing_marker.empty() && !err_loc.stack.empty()) { + std::string str(it, temptative_end); + auto last_non_sp_pos = str.find_last_not_of(" \n\r\t"); + if (last_non_sp_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); + } + auto last_non_sp_char = str[last_non_sp_pos]; + // Used to detect stops on a number, which may not be complete. + auto was_maybe_number = [&]() { + if (!str.empty() && std::isspace(str.back())) { + return false; + } + return std::isdigit(last_non_sp_char) || + last_non_sp_char == '.' || + last_non_sp_char == 'e' || + last_non_sp_char == 'E' || + last_non_sp_char == '-'; + }; + + std::string closing; + for (size_t i = err_loc.stack.size(); i > 0; i--) { + auto & el = err_loc.stack[i - 1]; + if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) { + closing += "}"; + } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) { + closing += "]"; + } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) { + throw std::runtime_error("Unexpected stack element type"); + } + } + + const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$"; + + if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) { + // We're inside an object value + if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) { + // Was about to create an object value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } else if (can_parse(str + ": 1" + closing)) { + str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing; + } else if (last_non_sp_char == '{' && can_parse(str + closing)) { + // Was about to create an object + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; + } else if (can_parse(str + "\"" + closing)) { + // Was inside an object value string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { + // Was inside an object value string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else { + // find last : + auto last_pos = str.find_last_of(':'); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); + } + // Cutting back to opening : for object value + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) { + if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) { + // Was about to create an array value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } else if (can_parse(str + "\"" + closing)) { + // Was inside an array value string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { + // Was inside an array value string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; + } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) { + // Had just finished a value + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing; + } else { + auto last_pos = str.find_last_of("[,"); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location"); + } + // Cutting back to last [ or , for array value + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) { + if ((last_non_sp_char == '{' && can_parse(str + closing)) || + (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) { + // Was about to create an object key+value + str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; + } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) { + // Was about to create an object key+value + str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing; + } else if (can_parse(str + "\": 1" + closing)) { + // Was inside an object key string + str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing; + } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) { + // Was inside an object key string after an escape + str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing; + } else { + auto last_pos = str.find_last_of(':'); + if (last_pos == std::string::npos) { + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); + } + // fprintf(stderr, "Cutting back to last : for object key+value\n"); + str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; + } + } else { + throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); + } + // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str()); + out.json = json::parse(str); + it = temptative_end; + return true; + } + // TODO: handle unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) + // fprintf(stderr, "Closing: TODO\n"); + return false; + } + out.json = json::parse(it, end); + it = end; + return true; +} diff --git a/common/json-partial.h b/common/json-partial.h new file mode 100644 index 000000000..f63356dc4 --- /dev/null +++ b/common/json-partial.h @@ -0,0 +1,38 @@ +#pragma once + +#include + +// Healing marker (empty if the JSON was fully parsed / wasn't healed). +struct common_healing_marker { + // Raw marker. + std::string marker; + + // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format). + std::string json_dump_marker; +}; + +// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string) +struct common_json { + nlohmann::ordered_json json; + + common_healing_marker healing_marker; +}; + +// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty. +// +// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON. +// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker. +// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format). +// +// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again). +bool common_json_parse( + const std::string & input, + const std::string & healing_marker, + common_json & out); + +// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds. +bool common_json_parse( + std::string::const_iterator & it, + const std::string::const_iterator & end, + const std::string & healing_marker, + common_json & out); diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 906798225..d38a74f95 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -1,8 +1,9 @@ #include "json-schema-to-grammar.h" #include "common.h" +#include + #include -#include #include #include #include @@ -16,6 +17,9 @@ using json = nlohmann::ordered_json; static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); + if (max_items == 0) { + return ""; + } if (min_items == 0 && max_items == 1) { return item_rule + "?"; } diff --git a/common/json-schema-to-grammar.h b/common/json-schema-to-grammar.h index 4613f5d9f..362991b54 100644 --- a/common/json-schema-to-grammar.h +++ b/common/json-schema-to-grammar.h @@ -1,9 +1,9 @@ #pragma once -#include "ggml.h" -// Change JSON_ASSERT from assert() to GGML_ASSERT: -#define JSON_ASSERT GGML_ASSERT -#include "json.hpp" +#include + +#include +#include std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, bool force_gbnf = false); diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 8bff89ea4..adce620e4 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -189,6 +189,7 @@ static LlgTokenizer * llama_sampler_llg_new_tokenizer(const llama_vocab * vocab) /* .tokenize_fn = */ llama_sampler_llg_tokenize_fn, /* .use_approximate_greedy_tokenize_fn = */ false, /* .tokenize_user_data = */ vocab, + /* .slices = */ nullptr, }; char error_buffer[1024]; diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp new file mode 100644 index 000000000..4bff6b663 --- /dev/null +++ b/common/regex-partial.cpp @@ -0,0 +1,204 @@ +#include "regex-partial.h" +#include "common.h" +#include +#include + +common_regex::common_regex(const std::string & pattern) : + pattern(pattern), + rx(pattern), + rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {} + +common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const { + std::smatch match; + if (pos > input.size()) { + throw std::runtime_error("Position out of bounds"); + } + auto start = input.begin() + pos; + auto found = as_match + ? std::regex_match(start, input.end(), match, rx) + : std::regex_search(start, input.end(), match, rx); + if (found) { + common_regex_match res; + res.type = COMMON_REGEX_MATCH_TYPE_FULL; + for (size_t i = 0; i < match.size(); ++i) { + auto begin = pos + match.position(i); + res.groups.emplace_back(begin, begin + match.length(i)); + } + return res; + } + std::match_results srmatch; + if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { + auto group = srmatch[1].str(); + if (group.length() != 0) { + auto it = srmatch[1].second.base(); + // auto position = static_cast(std::distance(input.begin(), it)); + if ((!as_match) || it == input.begin()) { + common_regex_match res; + res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL; + const size_t begin = std::distance(input.begin(), it); + const size_t end = input.size(); + if (begin == std::string::npos || end == std::string::npos || begin > end) { + throw std::runtime_error("Invalid range"); + } + res.groups.push_back({begin, end}); + return res; + } + } + } + return {}; +} + +/* + Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern. + + Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html) + to see if a string ends with a partial regex match, but but it's not in std::regex yet. + Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. + + - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* + - /a|b/ -> (a|b).* + - /a*?/ -> error, could match "" + - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) + - /.*?ab/ -> ((?:b)?a).* (merge .*) + - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) + - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* + - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* + - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* + + The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern + (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) +*/ +std::string regex_to_reversed_partial_regex(const std::string & pattern) { + auto it = pattern.begin(); + const auto end = pattern.end(); + + std::function process = [&]() { + std::vector> alternatives(1); + std::vector * sequence = &alternatives.back(); + + while (it != end) { + if (*it == '[') { + auto start = it; + ++it; + while (it != end) { + if ((*it == '\\') && (++it != end)) { + ++it; + } else if ((it != end) && (*it == ']')) { + break; + } else { + ++it; + } + } + if (it == end) { + throw std::runtime_error("Unmatched '[' in pattern"); + } + ++it; + sequence->push_back(std::string(start, it)); + } else if (*it == '*' || *it == '?' || *it == '+') { + if (sequence->empty()) { + throw std::runtime_error("Quantifier without preceding element"); + } + sequence->back() += *it; + auto is_star = *it == '*'; + ++it; + if (is_star) { + if (*it == '?') { + ++it; + } + } + } else if (*it == '{') { + if (sequence->empty()) { + throw std::runtime_error("Repetition without preceding element"); + } + ++it; + auto start = it; + while (it != end && *it != '}') { + ++it; + } + if (it == end) { + throw std::runtime_error("Unmatched '{' in pattern"); + } + auto parts = string_split(std::string(start, it), ","); + ++it; + if (parts.size() > 2) { + throw std::runtime_error("Invalid repetition range in pattern"); + } + + auto parseOptInt = [&](const std::string & s, const std::optional & def = std::nullopt) -> std::optional { + if (s.empty()) { + return def; + } + return std::stoi(s); + }; + auto min = parseOptInt(parts[0], 0); + auto max = parts.size() == 1 ? min : parseOptInt(parts[1]); + if (min && max && *max < *min) { + throw std::runtime_error("Invalid repetition range in pattern"); + } + // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded) + auto part = sequence->back(); + sequence->pop_back(); + for (int i = 0; i < *min; i++) { + sequence->push_back(part); + } + if (max) { + for (int i = *min; i < *max; i++) { + sequence->push_back(part + "?"); + } + } else { + sequence->push_back(part + "*"); + } + } else if (*it == '(') { + ++it; + if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') { + it += 2; + } + auto sub = process(); + if (*it != ')') { + throw std::runtime_error("Unmatched '(' in pattern"); + } + ++it; + auto & part = sequence->emplace_back("(?:"); + part += sub; + part += ")"; + } else if (*it == ')') { + break; + } else if (*it == '|') { + ++it; + alternatives.emplace_back(); + sequence = &alternatives.back(); + } else if (*it == '\\' && (++it != end)) { + auto str = std::string("\\") + *it; + sequence->push_back(str); + ++it; + } else if (it != end) { + sequence->push_back(std::string(1, *it)); + ++it; + } + } + + // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* + // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group + // We'll do the outermost capturing group and final .* in the enclosing function. + std::vector res_alts; + for (const auto & parts : alternatives) { + auto & res = res_alts.emplace_back(); + for (size_t i = 0; i < parts.size() - 1; i++) { + res += "(?:"; + } + for (auto it = parts.rbegin(); it != parts.rend(); ++it) { + res += *it; + if (it != parts.rend() - 1) { + res += ")?"; + } + } + } + return string_join(res_alts, "|"); + }; + auto res = process(); + if (it != end) { + throw std::runtime_error("Unmatched '(' in pattern"); + } + + return "(" + res + ")[\\s\\S]*"; +} diff --git a/common/regex-partial.h b/common/regex-partial.h new file mode 100644 index 000000000..634cb4022 --- /dev/null +++ b/common/regex-partial.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +enum common_regex_match_type { + COMMON_REGEX_MATCH_TYPE_NONE, + COMMON_REGEX_MATCH_TYPE_PARTIAL, + COMMON_REGEX_MATCH_TYPE_FULL, +}; + +struct common_string_range { + size_t begin; + size_t end; + common_string_range(size_t begin, size_t end) : begin(begin), end(end) { + if (begin > end) { + throw std::runtime_error("Invalid range"); + } + } + // prevent default ctor + common_string_range() = delete; + bool empty() const { + return begin == end; + } + bool operator==(const common_string_range & other) const { + return begin == other.begin && end == other.end; + } +}; + +struct common_regex_match { + common_regex_match_type type = COMMON_REGEX_MATCH_TYPE_NONE; + std::vector groups; + + bool operator==(const common_regex_match & other) const { + return type == other.type && groups == other.groups; + } + bool operator!=(const common_regex_match & other) const { + return !(*this == other); + } +}; + +class common_regex { + std::string pattern; + std::regex rx; + std::regex rx_reversed_partial; + + public: + explicit common_regex(const std::string & pattern); + + common_regex_match search(const std::string & input, size_t pos, bool as_match = false) const; + + const std::string & str() const { return pattern; } +}; + +// For testing only (pretty print of failures). +std::string regex_to_reversed_partial_regex(const std::string & pattern); diff --git a/common/sampling.cpp b/common/sampling.cpp index 1735b6501..9c04d35fd 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,6 +1,7 @@ #include "sampling.h" #include "common.h" +#include "log.h" #include #include @@ -160,7 +161,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { - std::vector patterns_at_start; + std::vector trigger_patterns; std::vector patterns_anywhere; std::vector trigger_tokens; for (const auto & trigger : params.grammar_triggers) { @@ -172,10 +173,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: { - const auto & pattern = trigger.value; - (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern); + patterns_anywhere.push_back(trigger.value); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: + { + trigger_patterns.push_back(trigger.value); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: @@ -189,10 +193,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - std::vector trigger_patterns; - if (!patterns_at_start.empty()) { - trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*"); - } if (!patterns_anywhere.empty()) { trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); } @@ -229,51 +229,48 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.logit_bias.data())); if (params.mirostat == 0) { - if (params.top_n_sigma >= 0) { - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); - llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); - } else { - for (const auto & cnstr : params.samplers) { - switch (cnstr) { - case COMMON_SAMPLER_TYPE_DRY: - { - std::vector c_breakers; - c_breakers.reserve(params.dry_sequence_breakers.size()); - for (const auto & str : params.dry_sequence_breakers) { - c_breakers.push_back(str.c_str()); - } - - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto & str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); } - break; - case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); - break; - case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); - break; - case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); - break; - case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); - break; - default: - GGML_ASSERT(false && "unknown sampler type"); - } + + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } + break; + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); } } llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); @@ -475,6 +472,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's'; case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_XTC: return 'x'; @@ -490,6 +488,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) { case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; + case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma"; case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_XTC: return "xtc"; @@ -504,6 +503,7 @@ std::vector common_sampler_types_from_names(const std::vect { "dry", COMMON_SAMPLER_TYPE_DRY }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P }, + { "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, @@ -517,6 +517,7 @@ std::vector common_sampler_types_from_names(const std::vect std::unordered_map sampler_alt_name_map { { "top-k", COMMON_SAMPLER_TYPE_TOP_K }, { "top-p", COMMON_SAMPLER_TYPE_TOP_P }, + { "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, { "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, @@ -533,14 +534,16 @@ std::vector common_sampler_types_from_names(const std::vect auto sampler = sampler_canonical_name_map.find(name); if (sampler != sampler_canonical_name_map.end()) { samplers.push_back(sampler->second); - } else { - if (allow_alt_names) { - sampler = sampler_alt_name_map.find(name); - if (sampler != sampler_alt_name_map.end()) { - samplers.push_back(sampler->second); - } + continue; + } + if (allow_alt_names) { + sampler = sampler_alt_name_map.find(name); + if (sampler != sampler_alt_name_map.end()) { + samplers.push_back(sampler->second); + continue; } } + LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str()); } return samplers; @@ -552,6 +555,7 @@ std::vector common_sampler_types_from_chars(const std::stri { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, + { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, @@ -566,6 +570,8 @@ std::vector common_sampler_types_from_chars(const std::stri const auto sampler = sampler_name_map.find(c); if (sampler != sampler_name_map.end()) { samplers.push_back(sampler->second); + } else { + LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c); } } diff --git a/common/speculative.cpp b/common/speculative.cpp index ccad70fa9..843bd1ddb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -144,6 +144,8 @@ llama_tokens common_speculative_gen_draft( auto & smpl = spec->smpl; auto & prompt = spec->prompt; + auto * mem = llama_get_memory(ctx); + int reuse_i = 0; int reuse_n = 0; @@ -173,7 +175,7 @@ llama_tokens common_speculative_gen_draft( result.reserve(params.n_draft); if (reuse_n == 0) { - llama_kv_self_clear(ctx); + llama_memory_clear(mem, false); prompt.clear(); } else { @@ -192,14 +194,14 @@ llama_tokens common_speculative_gen_draft( } if (reuse_i > 0) { - llama_kv_self_seq_rm (ctx, 0, 0, reuse_i); - llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i); + llama_memory_seq_rm (mem, 0, 0, reuse_i); + llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i); prompt.erase(prompt.begin(), prompt.begin() + reuse_i); } if (reuse_n < (int) prompt.size()) { - llama_kv_self_seq_rm (ctx, 0, reuse_n, -1); + llama_memory_seq_rm (mem, 0, reuse_n, -1); prompt.erase(prompt.begin() + reuse_n, prompt.end()); } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2bf97475f..b754dd815 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -16,6 +16,7 @@ from pathlib import Path from hashlib import sha256 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast from itertools import chain +from transformers import AutoConfig import math import numpy as np @@ -42,11 +43,19 @@ class SentencePieceTokenTypes(IntEnum): BYTE = 6 -AnyModel = TypeVar("AnyModel", bound="type[Model]") +class ModelType(IntEnum): + TEXT = 1 + MMPROJ = 2 -class Model: - _model_classes: dict[str, type[Model]] = {} +AnyModel = TypeVar("AnyModel", bound="type[ModelBase]") + + +class ModelBase: + _model_classes: dict[ModelType, dict[str, type[ModelBase]]] = { + ModelType.TEXT: {}, + ModelType.MMPROJ: {}, + } dir_model: Path ftype: gguf.LlamaFileType @@ -58,8 +67,6 @@ class Model: part_names: list[str] is_safetensors: bool hparams: dict[str, Any] - block_count: int - tensor_map: gguf.TensorNameMap tensor_names: set[str] | None gguf_writer: gguf.GGUFWriter model_name: str | None @@ -70,12 +77,18 @@ class Model: # subclasses should define this! model_arch: gguf.MODEL_ARCH - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False, + # subclasses should initialize this! + block_count: int + tensor_map: gguf.TensorNameMap + + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, use_temp_file: bool = False, eager: bool = False, metadata_override: Path | None = None, model_name: str | None = None, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None): - if type(self) is Model: + if type(self) is ModelBase or \ + type(self) is TextModel or \ + type(self) is MmprojModel: raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") self.dir_model = dir_model @@ -98,13 +111,11 @@ class Model: self.get_tensors = get_remote_tensors else: - self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors") + self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors") self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: - self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin") - self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) - self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") + self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams self.tensor_names = None self.metadata_override = metadata_override self.model_name = model_name @@ -126,11 +137,10 @@ class Model: split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @classmethod - def __init_subclass__(cls): - # can't use an abstract property, because overriding it without type errors - # would require using decorated functions instead of simply defining the property - if "model_arch" not in cls.__dict__: - raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}") + def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path: + stem, suffix = path.stem, path.suffix + new_name = f"{prefix}{stem}{suffix}" + return path.with_name(new_name) def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: key = next((k for k in keys if k in self.hparams), None) @@ -140,9 +150,6 @@ class Model: return None raise KeyError(f"could not find any of: {keys}") - def set_vocab(self): - self._set_vocab_gpt2() - def get_tensors(self) -> Iterator[tuple[str, Tensor]]: tensor_names_from_parts: set[str] = set() @@ -230,50 +237,7 @@ class Model: return new_name def set_gguf_parameters(self): - self.gguf_writer.add_block_count(self.block_count) - - if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None: - self.gguf_writer.add_context_length(n_ctx) - logger.info(f"gguf: context length = {n_ctx}") - - if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None: - self.gguf_writer.add_embedding_length(n_embd) - logger.info(f"gguf: embedding length = {n_embd}") - - if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None: - self.gguf_writer.add_feed_forward_length(n_ff) - logger.info(f"gguf: feed forward length = {n_ff}") - - if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None: - self.gguf_writer.add_head_count(n_head) - logger.info(f"gguf: head count = {n_head}") - - if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: - self.gguf_writer.add_head_count_kv(n_head_kv) - logger.info(f"gguf: key-value head count = {n_head_kv}") - - if (rope_theta := self.hparams.get("rope_theta")) is not None: - self.gguf_writer.add_rope_freq_base(rope_theta) - logger.info(f"gguf: rope theta = {rope_theta}") - if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: - self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) - logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") - if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: - self.gguf_writer.add_layer_norm_eps(f_norm_eps) - logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") - if (n_experts := self.hparams.get("num_local_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) - logger.info(f"gguf: expert count = {n_experts}") - if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: - self.gguf_writer.add_expert_used_count(n_experts_used) - logger.info(f"gguf: experts used count = {n_experts_used}") - - if (head_dim := self.hparams.get("head_dim")) is not None: - self.gguf_writer.add_key_length(head_dim) - self.gguf_writer.add_value_length(head_dim) - - self.gguf_writer.add_file_type(self.ftype) - logger.info(f"gguf: file type = {self.ftype}") + raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses") def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused @@ -344,6 +308,8 @@ class Model: gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED, gguf.MODEL_TENSOR.POSNET_NORM1, gguf.MODEL_TENSOR.POSNET_NORM2, + gguf.MODEL_TENSOR.V_ENC_EMBD_POS, + gguf.MODEL_TENSOR.A_ENC_EMBD_POS, ) ) or not new_name.endswith(".weight") @@ -419,6 +385,113 @@ class Model: if self.metadata.size_label is None and total_params > 0: self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count) + self.set_type() + + logger.info("Set meta model") + self.metadata.set_gguf_meta_model(self.gguf_writer) + + logger.info("Set model parameters") + self.set_gguf_parameters() + + logger.info("Set model quantization version") + self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + + def write_vocab(self): + raise NotImplementedError("write_vocab() must be implemented in subclasses") + + def write(self): + self.prepare_tensors() + self.prepare_metadata(vocab_only=False) + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + + @staticmethod + def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: + part_names: list[str] = [] + for filename in os.listdir(dir_model): + if filename.startswith(prefix) and filename.endswith(suffix): + part_names.append(filename) + + part_names.sort() + + return part_names + + @staticmethod + def load_hparams(dir_model: Path): + try: + # for security reason, we don't allow loading remote code by default + # if a model need remote code, we will fallback to config.json + config = AutoConfig.from_pretrained(dir_model, trust_remote_code=False).to_dict() + except Exception as e: + logger.warning(f"Failed to load model config from {dir_model}: {e}") + logger.warning("Trying to load config.json instead") + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + config = json.load(f) + if "llm_config" in config: + # rename for InternVL + config["text_config"] = config["llm_config"] + if "thinker_config" in config: + # rename for Qwen2.5-Omni + config["text_config"] = config["thinker_config"]["text_config"] + return config + + @classmethod + def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: + assert names + + def func(modelcls: AnyModel) -> AnyModel: + model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT + for name in names: + cls._model_classes[model_type][name] = modelcls + return modelcls + return func + + @classmethod + def print_registered_models(cls): + for model_type, model_classes in cls._model_classes.items(): + logger.error(f"{model_type.name} models:") + for name in sorted(model_classes.keys()): + logger.error(f" - {name}") + + @classmethod + def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type[ModelBase]: + try: + return cls._model_classes[model_type][arch] + except KeyError: + raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + + +class TextModel(ModelBase): + model_type = ModelType.TEXT + hf_arch: str + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hf_arch = get_model_architecture(self.hparams, self.model_type) + + if "text_config" in self.hparams: + # move the text_config to the root level + self.hparams = {**self.hparams, **self.hparams["text_config"]} + + self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + @classmethod + def __init_subclass__(cls): + # can't use an abstract property, because overriding it without type errors + # would require using decorated functions instead of simply defining the property + if "model_arch" not in cls.__dict__: + raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}") + + def set_vocab(self): + self._set_vocab_gpt2() + + def prepare_metadata(self, vocab_only: bool): + super().prepare_metadata(vocab_only=vocab_only) + + total_params = self.gguf_writer.get_total_parameter_count()[0] # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0' output_type: str = self.ftype.name.partition("_")[2] @@ -440,27 +513,54 @@ class Model: # Process templated file name with the output ftype, useful with the "auto" ftype self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type) - self.set_type() - - logger.info("Set meta model") - self.metadata.set_gguf_meta_model(self.gguf_writer) - - logger.info("Set model parameters") - self.set_gguf_parameters() - logger.info("Set model tokenizer") self.set_vocab() - logger.info("Set model quantization version") - self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION) + def set_gguf_parameters(self): + self.gguf_writer.add_block_count(self.block_count) - def write(self): - self.prepare_tensors() - self.prepare_metadata(vocab_only=False) - self.gguf_writer.write_header_to_file(path=self.fname_out) - self.gguf_writer.write_kv_data_to_file() - self.gguf_writer.write_tensors_to_file(progress=True) - self.gguf_writer.close() + if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None: + self.gguf_writer.add_context_length(n_ctx) + logger.info(f"gguf: context length = {n_ctx}") + + if (n_embd := self.find_hparam(["hidden_size", "n_embd", "dim"], optional=True)) is not None: + self.gguf_writer.add_embedding_length(n_embd) + logger.info(f"gguf: embedding length = {n_embd}") + + if (n_ff := self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"], optional=True)) is not None: + self.gguf_writer.add_feed_forward_length(n_ff) + logger.info(f"gguf: feed forward length = {n_ff}") + + if (n_head := self.find_hparam(["num_attention_heads", "n_head", "n_heads"], optional=True)) is not None: + self.gguf_writer.add_head_count(n_head) + logger.info(f"gguf: head count = {n_head}") + + if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None: + self.gguf_writer.add_head_count_kv(n_head_kv) + logger.info(f"gguf: key-value head count = {n_head_kv}") + + if (rope_theta := self.hparams.get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base(rope_theta) + logger.info(f"gguf: rope theta = {rope_theta}") + if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: + self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") + if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: + self.gguf_writer.add_layer_norm_eps(f_norm_eps) + logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") + if (n_experts := self.hparams.get("num_local_experts")) is not None: + self.gguf_writer.add_expert_count(n_experts) + logger.info(f"gguf: expert count = {n_experts}") + if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: + self.gguf_writer.add_expert_used_count(n_experts_used) + logger.info(f"gguf: experts used count = {n_experts_used}") + + if (head_dim := self.hparams.get("head_dim")) is not None: + self.gguf_writer.add_key_length(head_dim) + self.gguf_writer.add_value_length(head_dim) + + self.gguf_writer.add_file_type(self.ftype) + logger.info(f"gguf: file type = {self.ftype}") def write_vocab(self): if len(self.gguf_writer.tensors) != 1: @@ -471,44 +571,6 @@ class Model: self.gguf_writer.write_kv_data_to_file() self.gguf_writer.close() - @staticmethod - def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: - part_names: list[str] = [] - for filename in os.listdir(dir_model): - if filename.startswith(prefix) and filename.endswith(suffix): - part_names.append(filename) - - part_names.sort() - - return part_names - - @staticmethod - def load_hparams(dir_model: Path): - with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) - - @classmethod - def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: - assert names - - def func(modelcls: AnyModel) -> AnyModel: - for name in names: - cls._model_classes[name] = modelcls - return modelcls - return func - - @classmethod - def print_registered_models(cls): - for name in sorted(cls._model_classes.keys()): - logger.error(f"- {name}") - - @classmethod - def from_model_architecture(cls, arch: str) -> type[Model]: - try: - return cls._model_classes[arch] - except KeyError: - raise NotImplementedError(f'Architecture {arch!r} not supported!') from None - def does_token_look_special(self, token: str | bytes) -> bool: if isinstance(token, (bytes, bytearray)): token_text = token.decode(encoding="utf-8") @@ -612,12 +674,12 @@ class Model: if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": # ref: https://huggingface.co/tiiuae/falcon-7b res = "falcon" - if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": - # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base - res = "falcon3" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 res = "bert-bge" + if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": + # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base + res = "falcon3" if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7": # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5 res = "bert-bge-large" @@ -669,9 +731,6 @@ class Model: if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code res = "jina-v2-code" - if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b" or chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": - # ref: https://huggingface.co/THUDM/glm-4-9b-chat - res = "chatglm-bpe" if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee": # ref: https://huggingface.co/LumiOpen/Viking-7B res = "viking" @@ -702,9 +761,6 @@ class Model: if chkhsh == "60824e3c0d9401f89943cbb2fff727f0e2d4c545ba4df2d6e4f09a6db0f5b450": # ref: https://huggingface.co/facebook/chameleon-7b res = "chameleon" - if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": - # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 - res = "minerva-7b" if chkhsh == "8b5a93ed704057481f240da0be7e7dca721d7f8f4755263b6807227a2cbeae65": # ref: https://huggingface.co/sentence-transformers/stsb-roberta-base res = "roberta-bpe" @@ -735,9 +791,24 @@ class Model: if chkhsh == "d353350c764d8c3b39c763113960e4fb4919bea5fbf208a0e3b22e8469dc7406": # ref: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct res = "llama4" + if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3": + # ref: https://huggingface.co/mistral-community/pixtral-12b + res = "pixtral" + if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec": + # ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base + res = "seed-coder" + if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" + if chkhsh == "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516": + # ref: https://huggingface.co/THUDM/glm-4-9b-chat + res = "chatglm-bpe" if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2": # ref: https://huggingface.co/THUDM/glm-4-9b-hf res = "glm4" + if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": + # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 + res = "minerva-7b" if res is None: logger.warning("\n") @@ -976,6 +1047,10 @@ class Model: special_vocab.chat_template = "rwkv-world" # hack: Add '\n\n' as the EOT token to make it chat normally special_vocab._set_special_token("eot", 261) + # hack: Override these as they have already been set (incorrectly) + special_vocab.special_token_ids["bos"] = 0 + special_vocab.special_token_ids["eos"] = 0 + special_vocab.add_to_gguf(self.gguf_writer) def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int): @@ -1023,9 +1098,147 @@ class Model: if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.ADD_EOS)) is not None: self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0]) + def _try_set_pooling_type(self) -> None: + # get pooling path + pooling_path = None + module_path = self.dir_model / "modules.json" + if module_path.is_file(): + with open(module_path, encoding="utf-8") as f: + modules = json.load(f) + for mod in modules: + if mod["type"] == "sentence_transformers.models.Pooling": + pooling_path = mod["path"] + break -@Model.register("GPTNeoXForCausalLM") -class GPTNeoXModel(Model): + # get pooling type + if pooling_path is not None: + with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: + pooling = json.load(f) + if pooling["pooling_mode_mean_tokens"]: + pooling_type = gguf.PoolingType.MEAN + elif pooling["pooling_mode_cls_token"]: + pooling_type = gguf.PoolingType.CLS + elif pooling["pooling_mode_lasttoken"]: + pooling_type = gguf.PoolingType.LAST + else: + raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported") + self.gguf_writer.add_pooling_type(pooling_type) + + +class MmprojModel(ModelBase): + model_type = ModelType.MMPROJ + model_arch = gguf.MODEL_ARCH.MMPROJ + preprocessor_config: dict[str, Any] + global_config: dict[str, Any] + + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] + + has_vision_encoder: bool = True # by default + has_audio_encoder: bool = False + + # for models having multiple encoders, we need to separate their hparams + hparams_vision: dict[str, Any] | None = None + hparams_audio: dict[str, Any] | None = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.model_arch != gguf.MODEL_ARCH.MMPROJ: + raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ") + + # get n_embd of the text model + if "text_config" not in self.hparams: + self.hparams["text_config"] = {} + if "audio_config" not in self.hparams: + self.hparams["audio_config"] = {} + text_config = {**self.hparams, **self.hparams["text_config"]} + self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0)) + assert self.n_embd_text > 0, "n_embd not found in hparams" + + # move vision config to the top level, while preserving the original hparams in global_config + import copy + self.global_config = copy.deepcopy(self.hparams) + self.hparams_vision = self.get_vision_config() + self.hparams_audio = self.get_audio_config() + + if self.hparams_vision is None and self.hparams_audio is None: + raise ValueError("vision_config / audio_config not found in hparams") + + # for compat with vision-only models + self.hparams = self.hparams_vision or self.hparams_audio or self.hparams + + # TODO @ngxson : this is a hack to support both vision and audio encoders + have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder + self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) + self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) + + # load preprocessor config + with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: + self.preprocessor_config = json.load(f) + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config.get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config.get("audio_config") + + def set_type(self): + self.gguf_writer.add_type(gguf.GGUFType.MMPROJ) + + def set_gguf_parameters(self): + self.gguf_writer.add_file_type(self.ftype) + + if self.has_vision_encoder: + self.gguf_writer.add_clip_has_vision_encoder(True) + self.gguf_writer.add_vision_projection_dim(self.n_embd_text) + + # vision config + self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) + self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) + self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) + + # preprocessor config + self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"]) + + if self.has_audio_encoder: + self.gguf_writer.add_clip_has_audio_encoder(True) + self.gguf_writer.add_audio_projection_dim(self.n_embd_text) + + # audio config + self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"])) + self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"])) + self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys)) + self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"])) + + if not self.has_vision_encoder and not self.has_audio_encoder: + raise ValueError("MmprojModel must have either vision or audio encoder") + + def write_vocab(self): + raise ValueError("MmprojModel does not support vocab writing") + + def find_vparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_vision is not None + return self._find_param(self.hparams_vision, keys, optional) + + def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any: + assert self.hparams_audio is not None + return self._find_param(self.hparams_audio, keys, optional) + + def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any: + key = next((k for k in keys if k in obj), None) + if key is not None: + return obj[key] + if optional: + return None + raise KeyError(f"could not find any of: {keys}") + + +@ModelBase.register("GPTNeoXForCausalLM") +class GPTNeoXModel(TextModel): model_arch = gguf.MODEL_ARCH.GPTNEOX def set_gguf_parameters(self): @@ -1081,8 +1294,8 @@ class GPTNeoXModel(Model): return tensors -@Model.register("BloomForCausalLM", "BloomModel") -class BloomModel(Model): +@ModelBase.register("BloomForCausalLM", "BloomModel") +class BloomModel(TextModel): model_arch = gguf.MODEL_ARCH.BLOOM def set_gguf_parameters(self): @@ -1138,8 +1351,8 @@ class BloomModel(Model): return tensors -@Model.register("MPTForCausalLM") -class MPTModel(Model): +@ModelBase.register("MPTForCausalLM") +class MPTModel(TextModel): model_arch = gguf.MODEL_ARCH.MPT def set_vocab(self): @@ -1182,8 +1395,8 @@ class MPTModel(Model): return [(new_name, data_torch)] -@Model.register("OrionForCausalLM") -class OrionModel(Model): +@ModelBase.register("OrionForCausalLM") +class OrionModel(TextModel): model_arch = gguf.MODEL_ARCH.ORION def set_vocab(self): @@ -1217,8 +1430,8 @@ class OrionModel(Model): self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"]) -@Model.register("BaichuanForCausalLM", "BaiChuanForCausalLM") -class BaichuanModel(Model): +@ModelBase.register("BaichuanForCausalLM", "BaiChuanForCausalLM") +class BaichuanModel(TextModel): model_arch = gguf.MODEL_ARCH.BAICHUAN def set_vocab(self): @@ -1250,10 +1463,10 @@ class BaichuanModel(Model): self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_file_type(self.ftype) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: head_count = self.hparams["num_attention_heads"] @@ -1297,8 +1510,8 @@ class BaichuanModel(Model): return weights[r * n_part:r * n_part + r, ...] -@Model.register("XverseForCausalLM") -class XverseModel(Model): +@ModelBase.register("XverseForCausalLM") +class XverseModel(TextModel): model_arch = gguf.MODEL_ARCH.XVERSE def set_vocab(self): @@ -1374,10 +1587,10 @@ class XverseModel(Model): self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_file_type(self.ftype) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused @@ -1404,8 +1617,8 @@ class XverseModel(Model): ) -@Model.register("FalconForCausalLM", "RWForCausalLM") -class FalconModel(Model): +@ModelBase.register("FalconForCausalLM", "RWForCausalLM") +class FalconModel(TextModel): model_arch = gguf.MODEL_ARCH.FALCON def set_gguf_parameters(self): @@ -1458,8 +1671,8 @@ class FalconModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("GPTBigCodeForCausalLM") -class StarCoderModel(Model): +@ModelBase.register("GPTBigCodeForCausalLM") +class StarCoderModel(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER def set_gguf_parameters(self): @@ -1475,8 +1688,8 @@ class StarCoderModel(Model): self.gguf_writer.add_file_type(self.ftype) -@Model.register("GPTRefactForCausalLM") -class RefactModel(Model): +@ModelBase.register("GPTRefactForCausalLM") +class RefactModel(TextModel): model_arch = gguf.MODEL_ARCH.REFACT def set_vocab(self): @@ -1539,8 +1752,8 @@ class RefactModel(Model): return tensors -@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") -class StableLMModel(Model): +@ModelBase.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") +class StableLMModel(TextModel): model_arch = gguf.MODEL_ARCH.STABLELM def set_vocab(self): @@ -1629,11 +1842,24 @@ class StableLMModel(Model): raise ValueError(f"Unprocessed norms: {norms}") -@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") -class LlamaModel(Model): +@ModelBase.register( + "LLaMAForCausalLM", + "LlamaForCausalLM", + "MistralForCausalLM", + "MixtralForCausalLM", + "VLlama3ForCausalLM", + "LlavaForConditionalGeneration", + "LlamaModel") +class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA undo_permute = True + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # fix for SmolVLM2, missing `num_attention_heads` in config.json + if self.hf_arch == "VLlama3ForCausalLM": + self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) + def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -1672,16 +1898,14 @@ class LlamaModel(Model): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): @@ -1696,6 +1920,19 @@ class LlamaModel(Model): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + is_vision_tensor = "vision_tower" in name \ + or "vision_model" in name \ + or "model.connector" in name \ + or "multi_modal_projector" in name + + if is_vision_tensor: + return [] # skip vision tensors + elif self.hf_arch == "LlamaModel": + name = "model." + name + elif name.startswith("model.text_model"): + name = name.replace("text_model.", "") # for SmolVLM + elif name.startswith("language_model."): + name = name.replace("language_model.", "") # for the rest if self.undo_permute: if name.endswith(("q_proj.weight", "q_proj.bias")): @@ -1743,7 +1980,8 @@ class LlamaModel(Model): if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -1778,23 +2016,129 @@ class LlamaModel(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("Llama4ForConditionalGeneration") +@ModelBase.register("ArceeForCausalLM") +class ArceeModel(LlamaModel): + model_arch = gguf.MODEL_ARCH.ARCEE + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self._try_set_pooling_type() + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + + +@ModelBase.register( + "LlavaForConditionalGeneration", # pixtral + "Mistral3ForConditionalGeneration", # mistral small 3.1 +) +class LlavaVisionModel(MmprojModel): + img_break_tok_id = -1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.hparams["model_type"] == "pixtral": + # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py + self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5) + self.img_break_tok_id = self.get_token_id("[IMG_BREAK]") + logger.info(f"Image break token id: {self.img_break_tok_id}") + else: + raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") + + def get_token_id(self, token: str) -> int: + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + added_tokens_decoder = json.load(f)['added_tokens_decoder'] + for id_, token_data in added_tokens_decoder.items(): + if token_data["content"] == token: + return int(id_) + raise ValueError(f"Token '{token}' not found in tokenizer config.") + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + if hparams["model_type"] == "pixtral": + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) + + # hidden_act + if hparams["hidden_act"] == "silu": + self.gguf_writer.add_vision_use_silu(True) + elif hparams["hidden_act"] == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + else: + raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}") + + # spatial_merge_size + if "spatial_merge_size" in self.global_config: + self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + n_head = self.hparams["num_attention_heads"] + n_kv_head = n_head + + if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."): + # process vision tensors + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) + return [(self.map_tensor_name(name), data_torch)] + + if self.img_break_tok_id > 0 and "embed_tokens.weight" in name: + logger.info(f"Extracting [IMG_BREAK] token embedding from {name}") + # for pixtral model, we need to extract the [IMG_BREAK] token embedding + img_break_embd = data_torch[self.img_break_tok_id] + name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK] + return [(self.map_tensor_name(name), img_break_embd)] + + return [] # skip other tensors + + +@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration") +class SmolVLMModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.hparams["model_type"] == "smolvlm_vision": + # fix for SmolVLM2, missing some keys in config.json + # default values are taken from transformers code + self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152) + self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16) + self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2)) + self.gguf_writer.add_vision_use_gelu(True) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".embeddings." in name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name + + if is_vision_tensor: + return [(self.map_tensor_name(name), data_torch)] + + return [] # skip other tensors + + +@ModelBase.register("Llama4ForConditionalGeneration") class Llama4Model(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA4 - has_vision: bool = False undo_permute = False - # TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config" - # same with llama, but we need to merge the text_config into the root level of hparams def __init__(self, *args, **kwargs): - hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0]) - if "text_config" in hparams: - hparams = {**hparams, **hparams["text_config"]} - kwargs["hparams"] = hparams super().__init__(*args, **kwargs) - if "vision_config" in hparams: - logger.info("Has vision encoder, but it will be ignored") - self.has_vision = True # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"] self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"] @@ -1809,6 +2153,9 @@ class Llama4Model(LlamaModel): self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + if name.startswith("language_model."): + name = name.replace("language_model.", "") + # split the gate_up into gate and up if "gate_up_proj" in name: name_up = name.replace("gate_up_proj", "up_proj.weight") @@ -1829,18 +2176,33 @@ class Llama4Model(LlamaModel): return super().modify_tensors(data_torch, name, bid) -@Model.register("Mistral3ForConditionalGeneration") +@ModelBase.register("Llama4ForConditionalGeneration") +class Llama4VisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"]) + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"])) + assert self.hparams["hidden_act"] == "gelu" + self.gguf_writer.add_vision_use_gelu(True) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if "multi_modal_projector" in name or "vision_model" in name: + # process vision tensors + if "positional_embedding_vlm" in name and ".weight" not in name: + name += ".weight" + if "multi_modal_projector.linear_1" in name: + # despite the name with number postfix, this is a single fully connected layer + return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC], data_torch)] + return [(self.map_tensor_name(name), data_torch)] + return [] + + +@ModelBase.register("Mistral3ForConditionalGeneration") class Mistral3Model(LlamaModel): model_arch = gguf.MODEL_ARCH.LLAMA - # we need to merge the text_config into the root level of hparams - def __init__(self, *args, **kwargs): - hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0]) - if "text_config" in hparams: - hparams = {**hparams, **hparams["text_config"]} - kwargs["hparams"] = hparams - super().__init__(*args, **kwargs) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name = name.replace("language_model.", "") if "multi_modal_projector" in name or "vision_tower" in name: @@ -1848,8 +2210,8 @@ class Mistral3Model(LlamaModel): return super().modify_tensors(data_torch, name, bid) -@Model.register("DeciLMForCausalLM") -class DeciModel(Model): +@ModelBase.register("DeciLMForCausalLM") +class DeciModel(TextModel): model_arch = gguf.MODEL_ARCH.DECI @staticmethod @@ -1884,6 +2246,9 @@ class DeciModel(Model): # if n_heads_in_group is not None, then # _num_kv_heads[il] is num_attention_head // n_heads_in_group and # _num_heads[il] is num_attention_head + # ***dummy layer*** for nemotron 253B + # if n_heads_in_group is None and ffn_mult is None + # then _num_kv_heads[il] is 0 and _num_heads[il] is 0 and _ffn_dims is 0 for il in range(len(_block_configs)): if _block_configs[il]["attention"]["n_heads_in_group"] is None: if _block_configs[il]["attention"]["replace_with_linear"] is True: @@ -1895,7 +2260,10 @@ class DeciModel(Model): else: self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"]) self._num_heads.append(self.hparams["num_attention_heads"]) - _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) + if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer + _ffn_multipliers.append(0.0) + else: + _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) assert self.block_count == len(self._num_kv_heads) assert self.block_count == len(self._num_heads) assert self.block_count == len(_ffn_multipliers) @@ -1949,16 +2317,14 @@ class DeciModel(Model): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): @@ -1991,7 +2357,8 @@ class DeciModel(Model): if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -2020,8 +2387,8 @@ class DeciModel(Model): super().prepare_tensors() -@Model.register("BitnetForCausalLM") -class BitnetModel(Model): +@ModelBase.register("BitnetForCausalLM") +class BitnetModel(TextModel): model_arch = gguf.MODEL_ARCH.BITNET def set_vocab(self): @@ -2061,8 +2428,8 @@ class BitnetModel(Model): yield (new_name, data_torch) -@Model.register("GrokForCausalLM") -class GrokModel(Model): +@ModelBase.register("GrokForCausalLM") +class GrokModel(TextModel): model_arch = gguf.MODEL_ARCH.GROK def set_vocab(self): @@ -2114,8 +2481,8 @@ class GrokModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("DbrxForCausalLM") -class DbrxModel(Model): +@ModelBase.register("DbrxForCausalLM") +class DbrxModel(TextModel): model_arch = gguf.MODEL_ARCH.DBRX def set_gguf_parameters(self): @@ -2183,8 +2550,8 @@ class DbrxModel(Model): return n_dims > 1 -@Model.register("MiniCPMForCausalLM") -class MiniCPMModel(Model): +@ModelBase.register("MiniCPMForCausalLM") +class MiniCPMModel(TextModel): model_arch = gguf.MODEL_ARCH.MINICPM def set_gguf_parameters(self): @@ -2198,10 +2565,10 @@ class MiniCPMModel(Model): logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"] self.gguf_writer.add_logit_scale(logit_scale) logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}") - if self.hparams.get("rope_scaling") is not None: - if self.hparams["rope_scaling"].get("type") == "longrope": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE) - logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}") + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "longrope": + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE) + logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}") def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] @@ -2238,8 +2605,8 @@ class MiniCPMModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("MiniCPM3ForCausalLM") -class MiniCPM3Model(Model): +@ModelBase.register("MiniCPM3ForCausalLM") +class MiniCPM3Model(TextModel): model_arch = gguf.MODEL_ARCH.MINICPM3 def set_gguf_parameters(self): @@ -2291,8 +2658,8 @@ class MiniCPM3Model(Model): ) -@Model.register("QWenLMHeadModel") -class QwenModel(Model): +@ModelBase.register("QWenLMHeadModel") +class QwenModel(TextModel): model_arch = gguf.MODEL_ARCH.QWEN @staticmethod @@ -2333,8 +2700,8 @@ class QwenModel(Model): self.gguf_writer.add_file_type(self.ftype) -@Model.register("Qwen2ForCausalLM") -class Qwen2Model(Model): +@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration") +class Qwen2Model(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2 def set_vocab(self): @@ -2345,15 +2712,32 @@ class Qwen2Model(Model): def set_gguf_parameters(self): super().set_gguf_parameters() - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "yarn": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + self._try_set_pooling_type() + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if self.hf_arch == "Qwen2Model": + name = f"model.{name}" # map to Qwen2ForCausalLM tensors + if "language_model." in name: + name = name.replace("language_model.", "") # for InternVL + if name.startswith("mlp") or name.startswith("multi_modal_projector") \ + or name.startswith("vision_model") or name.startswith("audio_tower"): + # skip vision and audio tensors + return [] + yield from super().modify_tensors(data_torch, name, bid) -@Model.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") -class Qwen2VLModel(Model): +@ModelBase.register( + "Qwen2VLModel", + "Qwen2VLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5OmniModel", +) +class Qwen2VLModel(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2VL def set_gguf_parameters(self): @@ -2368,15 +2752,217 @@ class Qwen2VLModel(Model): except FileNotFoundError: self._set_vocab_gpt2() - def get_tensors(self) -> Iterator[tuple[str, Tensor]]: - for name, data in super().get_tensors(): - if name.startswith("visual."): - continue - yield name, data + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("thinker."): + name = name.replace("thinker.", "") + if name.startswith("visual") or name.startswith("audio") or \ + name.startswith("talker") or name.startswith("token2wav"): + # skip multimodal tensors + return [] + return [(self.map_tensor_name(name), data_torch)] -@Model.register("WavTokenizerDec") -class WavTokenizerDecModel(Model): +@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") +class Qwen2VLVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) + # rename config.json values + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") + if "embed_dim" in self.hparams_vision: # qwen2vl + self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size") + self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim") + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_vision is not None + hparams = self.hparams_vision + model_type = self.global_config['model_type'] + if model_type == 'qwen2_vl': + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL) + elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni': + if model_type == 'qwen2_5_omni': + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O) + else: + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL) + self.gguf_writer.add_vision_use_silu(True) + # find n_wa_pattern (window attention pattern) + fullatt_block_indexes = hparams.get("fullatt_block_indexes") + assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl" + n_wa_pattern = fullatt_block_indexes[0] + 1 + # validate n_wa_pattern + for i in range(1, len(fullatt_block_indexes)): + if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern: + raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}") + self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern) + else: + raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}") + # default values below are taken from HF tranformers code + self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, name, n_dims # unused + if ".patch_embd." in new_name: + return gguf.GGMLQuantizationType.F16 + if ".position_embd." in new_name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("visual."): + # process visual tensors + # split QKV tensors if needed + if ".qkv." in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("qkv", "q")), wq), + (self.map_tensor_name(name.replace("qkv", "k")), wk), + (self.map_tensor_name(name.replace("qkv", "v")), wv), + ] + elif 'patch_embed.proj.weight' in name: + # split Conv3D into Conv2Ds + c1, c2, kt, kh, kw = data_torch.shape + del c1, c2, kh, kw # unused + assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + return [ + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]), + (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]), + ] + else: + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + +@ModelBase.register("Qwen2_5OmniModel") +class Qwen25OmniModel(Qwen2VLVisionModel): + has_vision_encoder = True + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_audio is not None + self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"] + self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"] + self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + assert self.hparams_audio is not None + self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5)) + + def get_vision_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("vision_config") + + def get_audio_config(self) -> dict[str, Any] | None: + return self.global_config["thinker_config"].get("audio_config") + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # SinusoidsPositionEmbedding + assert self.hparams_audio is not None + max_timescale = 10000 + length = 1500 + channels = self.hparams_audio["hidden_size"] + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32) + yield ("audio_tower.embed_positions.weight", pos_embd) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("thinker."): + name = name.replace("thinker.", "") + + if name.startswith("audio_tower"): + # process audio tensors + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + if "audio_bos_eos_token" in name: + # this tensor is left unused in transformers code + # https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py#L1809 + return [] + return [(self.map_tensor_name(name), data_torch)] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("InternVisionModel") +class InternVisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.INTERNVL) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) + # hidden_act + if hparams["hidden_act"] == "silu": + self.gguf_writer.add_vision_use_silu(True) + elif hparams["hidden_act"] == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + else: + raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}") + # downsample_ratio + downsample_ratio = self.global_config.get("downsample_ratio") + assert downsample_ratio is not None + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, name, n_dims # unused + if ".patch_embd." in new_name: + return gguf.GGMLQuantizationType.F16 + if ".position_embd." in new_name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("vision_model") or name.startswith("mlp"): + # process visual tensors + # correct name + if name.startswith("vision_model"): + name = "vision_tower." + name + if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"): + name += ".weight" + # split QKV tensors if needed + if ".qkv." in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv), + ] + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + +@ModelBase.register("WavTokenizerDec") +class WavTokenizerDecModel(TextModel): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: @@ -2413,8 +2999,8 @@ class WavTokenizerDecModel(Model): self.gguf_writer.add_causal_attention(False) -@Model.register("Qwen2MoeForCausalLM") -class Qwen2MoeModel(Model): +@ModelBase.register("Qwen2MoeForCausalLM") +class Qwen2MoeModel(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2MOE def set_gguf_parameters(self): @@ -2427,6 +3013,13 @@ class Qwen2MoeModel(Model): if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None: self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size) logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}") + # YaRN is not enabled by default + # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) _experts: list[dict[str, Tensor]] | None = None @@ -2476,18 +3069,18 @@ class Qwen2MoeModel(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("Qwen3ForCausalLM") +@ModelBase.register("Qwen3ForCausalLM") class Qwen3Model(Qwen2Model): model_arch = gguf.MODEL_ARCH.QWEN3 -@Model.register("Qwen3MoeForCausalLM") +@ModelBase.register("Qwen3MoeForCausalLM") class Qwen3MoeModel(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3MOE -@Model.register("GPT2LMHeadModel") -class GPT2Model(Model): +@ModelBase.register("GPT2LMHeadModel") +class GPT2Model(TextModel): model_arch = gguf.MODEL_ARCH.GPT2 def set_gguf_parameters(self): @@ -2518,8 +3111,8 @@ class GPT2Model(Model): return tensors -@Model.register("PhiForCausalLM") -class Phi2Model(Model): +@ModelBase.register("PhiForCausalLM") +class Phi2Model(TextModel): model_arch = gguf.MODEL_ARCH.PHI2 def set_gguf_parameters(self): @@ -2542,8 +3135,8 @@ class Phi2Model(Model): self.gguf_writer.add_add_bos_token(False) -@Model.register("Phi3ForCausalLM") -class Phi3MiniModel(Model): +@ModelBase.register("Phi3ForCausalLM") +class Phi3MiniModel(TextModel): model_arch = gguf.MODEL_ARCH.PHI3 def set_vocab(self): @@ -2694,7 +3287,7 @@ class Phi3MiniModel(Model): scale = max_pos_embds / orig_max_pos_embds - rope_scaling_type = rope_scaling.get('type', '').lower() + rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower() if len(rope_scaling_type) == 0: raise KeyError('Missing the required key rope_scaling.type') @@ -2720,7 +3313,7 @@ class Phi3MiniModel(Model): yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) -@Model.register("PhiMoEForCausalLM") +@ModelBase.register("PhiMoEForCausalLM") class PhiMoeModel(Phi3MiniModel): model_arch = gguf.MODEL_ARCH.PHIMOE @@ -2777,8 +3370,8 @@ class PhiMoeModel(Phi3MiniModel): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("PlamoForCausalLM") -class PlamoModel(Model): +@ModelBase.register("PlamoForCausalLM") +class PlamoModel(TextModel): model_arch = gguf.MODEL_ARCH.PLAMO def set_vocab(self): @@ -2825,8 +3418,8 @@ class PlamoModel(Model): return [(new_name, data_torch)] -@Model.register("CodeShellForCausalLM") -class CodeShellModel(Model): +@ModelBase.register("CodeShellForCausalLM") +class CodeShellModel(TextModel): model_arch = gguf.MODEL_ARCH.CODESHELL def set_gguf_parameters(self): @@ -2866,8 +3459,8 @@ class CodeShellModel(Model): return [(new_name, data_torch)] -@Model.register("InternLM2ForCausalLM") -class InternLM2Model(Model): +@ModelBase.register("InternLM2ForCausalLM") +class InternLM2Model(TextModel): model_arch = gguf.MODEL_ARCH.INTERNLM2 def set_vocab(self): @@ -3006,10 +3599,10 @@ class InternLM2Model(Model): self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) self.gguf_writer.add_file_type(self.ftype) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: num_heads = self.hparams["num_attention_heads"] @@ -3019,6 +3612,11 @@ class InternLM2Model(Model): head_dim = n_embd // num_heads num_groups = num_heads // q_per_kv + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] + if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: qkv = data_torch @@ -3039,8 +3637,8 @@ class InternLM2Model(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("InternLM3ForCausalLM") -class InternLM3Model(Model): +@ModelBase.register("InternLM3ForCausalLM") +class InternLM3Model(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA def set_vocab(self): @@ -3078,20 +3676,22 @@ class InternLM3Model(Model): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "linear" or self.hparams["rope_scaling"].get("rope_type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): @@ -3099,40 +3699,27 @@ class InternLM3Model(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("BertModel", "BertForMaskedLM", "CamembertModel") -class BertModel(Model): +@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification") +class BertModel(TextModel): model_arch = gguf.MODEL_ARCH.BERT def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.vocab_size = None + if cls_out_labels := self.hparams.get("id2label"): + if len(cls_out_labels) == 2 and cls_out_labels[0] == "LABEL_0": + # Remove dummy labels added by AutoConfig + cls_out_labels = None + self.cls_out_labels = cls_out_labels + def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_causal_attention(False) + self._try_set_pooling_type() - # get pooling path - pooling_path = None - module_path = self.dir_model / "modules.json" - if module_path.is_file(): - with open(module_path, encoding="utf-8") as f: - modules = json.load(f) - for mod in modules: - if mod["type"] == "sentence_transformers.models.Pooling": - pooling_path = mod["path"] - break - - # get pooling type - if pooling_path is not None: - with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: - pooling = json.load(f) - if pooling["pooling_mode_mean_tokens"]: - pooling_type = gguf.PoolingType.MEAN - elif pooling["pooling_mode_cls_token"]: - pooling_type = gguf.PoolingType.CLS - else: - raise NotImplementedError("Only MEAN and CLS pooling types supported") - self.gguf_writer.add_pooling_type(pooling_type) + if self.cls_out_labels: + self.gguf_writer.add_classifier_output_labels([v for k, v in sorted(self.cls_out_labels.items())]) def set_vocab(self): tokens, toktypes, tokpre = self.get_vocab_base() @@ -3184,10 +3771,178 @@ class BertModel(Model): if name.startswith("cls.seq_relationship"): return [] + if self.cls_out_labels: + # For BertForSequenceClassification (direct projection layer) + if name == "classifier.weight": + name = "classifier.out_proj.weight" + + if name == "classifier.bias": + name = "classifier.out_proj.bias" + return [(self.map_tensor_name(name), data_torch)] + def _xlmroberta_tokenizer_init(self) -> None: + # we need the pad_token_id to know how to chop down position_embd matrix + if (pad_token_id := self.hparams.get("pad_token_id")) is not None: + self._position_offset = 1 + pad_token_id + if "max_position_embeddings" in self.hparams: + self.hparams["max_position_embeddings"] -= self._position_offset + else: + self._position_offset = None -@Model.register("RobertaModel") + def _xlmroberta_set_vocab(self) -> None: + # to avoid TypeError: Descriptors cannot be created directly + # exception when importing sentencepiece_model_pb2 + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from sentencepiece import SentencePieceProcessor + from sentencepiece import sentencepiece_model_pb2 as model + + tokenizer_path = self.dir_model / 'sentencepiece.bpe.model' + + tokenizer_json = {} + tokenizer_config_json = {} + if not tokenizer_path.is_file(): + tokenizer_path = self.dir_model / 'tokenizer.json' + tokenizer_config_path = self.dir_model / 'tokenizer_config.json' + + if not tokenizer_path.is_file(): + raise FileNotFoundError(f"File not found: {tokenizer_path}") + + from base64 import b64decode + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + + with open(tokenizer_path, "r", encoding="utf-8") as fp: + tokenizer_json = json.load(fp) + + if tokenizer_config_path.is_file(): + with open(tokenizer_config_path, "r", encoding="utf-8") as fp: + tokenizer_config_json = json.load(fp) + + add_prefix = tokenizer.add_prefix_space + remove_whitespaces = tokenizer.clean_up_tokenization_spaces + precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"]) + + vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) + else: + sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) + assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM + + add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix + remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) + + vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size + + if isinstance(tokenizer, SentencePieceProcessor): + for token_id in range(tokenizer.vocab_size()): + piece = tokenizer.IdToPiece(token_id) + text = piece.encode("utf-8") + score = tokenizer.GetScore(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.IsUnknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.IsControl(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.IsUnused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.IsByte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + else: + added_vocab = tokenizer.get_added_vocab() + unk_token = tokenizer_config_json.get("unk_token") + unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3)) + + for token_id in range(tokenizer.vocab_size): + piece = tokenizer._convert_id_to_token(token_id) + if (piece := tokenizer._convert_id_to_token(token_id)) is not None: + text = piece.encode("utf-8") + score = tokenizer_json["model"]["vocab"][token_id][1] + + toktype = SentencePieceTokenTypes.NORMAL + if token_id == unk_token_id: + toktype = SentencePieceTokenTypes.UNKNOWN + elif token_id in tokenizer.all_special_ids: + toktype = SentencePieceTokenTypes.CONTROL + elif token_id in added_vocab.values(): + toktype = SentencePieceTokenTypes.USER_DEFINED + # No reliable way to detect this, but jina doesn't have any + # elif tokenizer.IsByte(token_id): + # toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + if isinstance(tokenizer, SentencePieceProcessor): + # realign tokens (see HF tokenizer code) + tokens = [b'', b'', b'', b''] + tokens[3:-1] + scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1] + toktypes = [ + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.CONTROL, + SentencePieceTokenTypes.UNKNOWN, + ] + toktypes[3:-1] + + if self.model_arch == gguf.MODEL_ARCH.NOMIC_BERT_MOE: + # Add mask token missing from sentencepiece.bpe.model + tokens[250001] = b'' + scores[250001] = 0.0 + toktypes[250001] = SentencePieceTokenTypes.CONTROL + + self.gguf_writer.add_tokenizer_model("t5") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + self.gguf_writer.add_add_space_prefix(add_prefix) + self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) + self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) + if precompiled_charsmap: + self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + self.gguf_writer.add_add_bos_token(True) + self.gguf_writer.add_add_eos_token(True) + + +@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification") +class DistilBertModel(BertModel): + model_arch = gguf.MODEL_ARCH.BERT + + def set_gguf_parameters(self): + self.gguf_writer.add_layer_norm_eps(1e-12) + logger.info("gguf: layer norm epsilon = 1e-12") + super().set_gguf_parameters() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("distilbert."): + name = name[11:] + + # These layers act as MLM head, so we don't need them + if name.startswith("vocab_"): + return [] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("RobertaModel", "RobertaForSequenceClassification") class RobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -3232,24 +3987,41 @@ class RobertaModel(BertModel): return super().modify_tensors(data_torch, name, bid) -@Model.register("NomicBertModel") +@ModelBase.register("NomicBertModel") class NomicBertModel(BertModel): - model_arch = gguf.MODEL_ARCH.NOMIC_BERT + model_arch = gguf.MODEL_ARCH.BERT - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any): + hparams = kwargs.pop("hparams", None) + if hparams is None: + hparams = ModelBase.load_hparams(dir_model) - # the HF config claims n_ctx=8192, but it uses RoPE scaling - self.hparams["n_ctx"] = 2048 + self.is_moe = bool(hparams.get("moe_every_n_layers")) + self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT + + super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs) + + self._tokenizer_is_xlmroberta = self._is_tokenizer_xlmroberta() + if self._tokenizer_is_xlmroberta: + self._xlmroberta_tokenizer_init() + + npos, mtp = self.hparams["n_positions"], self.hparams.get("max_trained_positions", 2048) + if npos == 8192 and mtp == 2048: + self.hparams["n_positions"] = 2048 # nomic-embed-text v1 and v1.5 are trained for 2048 tokens. + elif npos == 2048 and mtp == 2048: + self.hparams["n_positions"] = 512 # nomic-embed-text-v2-moe is trained for 512 tokens. + else: + raise ValueError(f"unrecognized parameters: n_positions={npos}, max_trained_positions={mtp}") + + assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu" - # SwigLU activation - assert self.hparams["activation_function"] == "swiglu" # this doesn't do anything in the HF version assert self.hparams["causal"] is False - # no bias tensors - assert self.hparams["qkv_proj_bias"] is False - assert self.hparams["mlp_fc1_bias"] is False - assert self.hparams["mlp_fc2_bias"] is False + # no bias tensors unless MoE + assert self.hparams["qkv_proj_bias"] == self.is_moe + assert self.hparams["mlp_fc1_bias"] == self.is_moe + assert self.hparams["mlp_fc2_bias"] == self.is_moe + # norm at end of layer assert self.hparams["prenorm"] is False # standard RoPE @@ -3257,107 +4029,84 @@ class NomicBertModel(BertModel): assert self.hparams["rotary_emb_interleaved"] is False assert self.hparams["rotary_emb_scale_base"] is None + def set_vocab(self) -> None: + if self._tokenizer_is_xlmroberta: + return self._xlmroberta_set_vocab() + return super().set_vocab() + + def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]: + # If the tensor is an experts bias tensor, skip it by returning an empty list. + if "mlp.experts.bias" in name: + return [] # Explicitly return an empty list. + + if "mlp.experts.mlp.w1" in name: + data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + name += ".weight" + + if "mlp.experts.mlp.w2" in name: + data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.transpose(1, 2) + name += ".weight" + + return [(self.map_tensor_name(name), data_torch)] + def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) + if self.is_moe: + self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"]) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"]) + + def _is_tokenizer_xlmroberta(self) -> bool: + with open(self.dir_model / "tokenizer.json") as f: + tokenizer_json = json.load(f) + toktyp = tokenizer_json["model"]["type"] + if toktyp == "Unigram": + return True + if toktyp == "WordPiece": + return False + raise ValueError(f"unknown tokenizer: {toktyp}") -@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") +@ModelBase.register("NeoBERT", "NeoBERTLMHead", "NeoBERTForSequenceClassification") +class NeoBert(BertModel): + model_arch = gguf.MODEL_ARCH.NEO_BERT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # NeoBERT uses 2/3 of the intermediate size as feed forward length + self.gguf_writer.add_feed_forward_length(int(2 * self.hparams["intermediate_size"] / 3)) + self.gguf_writer.add_rope_freq_base(10000.0) # default value for NeoBERT + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + f_rms_eps = self.hparams.get("norm_eps", 1e-6) # default value for NeoBERT + self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) + logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") + + self.gguf_writer.add_pooling_type(gguf.PoolingType.CLS) # https://huggingface.co/chandar-lab/NeoBERT#how-to-use + + def modify_tensors(self, data_torch, name, bid): + if name.startswith("decoder."): + return [] + + if name.startswith("model."): + name = name[6:] + + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - - # we need the pad_token_id to know how to chop down position_embd matrix - if (pad_token_id := self.hparams.get("pad_token_id")) is not None: - self._position_offset = 1 + pad_token_id - if "max_position_embeddings" in self.hparams: - self.hparams["max_position_embeddings"] -= self._position_offset - else: - self._position_offset = None + self._xlmroberta_tokenizer_init() def set_vocab(self): - # to avoid TypeError: Descriptors cannot be created directly - # exception when importing sentencepiece_model_pb2 - os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" - from sentencepiece import SentencePieceProcessor - from sentencepiece import sentencepiece_model_pb2 as model - - tokenizer_path = self.dir_model / 'sentencepiece.bpe.model' - if not tokenizer_path.is_file(): - raise FileNotFoundError(f"File not found: {tokenizer_path}") - - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] - sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) - assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM - - add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix - remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces - precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap - - tokenizer = SentencePieceProcessor() - tokenizer.LoadFromFile(str(tokenizer_path)) - - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) - - tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] - scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size - - for token_id in range(tokenizer.vocab_size()): - piece = tokenizer.IdToPiece(token_id) - text = piece.encode("utf-8") - score = tokenizer.GetScore(token_id) - - toktype = SentencePieceTokenTypes.NORMAL - if tokenizer.IsUnknown(token_id): - toktype = SentencePieceTokenTypes.UNKNOWN - elif tokenizer.IsControl(token_id): - toktype = SentencePieceTokenTypes.CONTROL - elif tokenizer.IsUnused(token_id): - toktype = SentencePieceTokenTypes.UNUSED - elif tokenizer.IsByte(token_id): - toktype = SentencePieceTokenTypes.BYTE - - tokens[token_id] = text - scores[token_id] = score - toktypes[token_id] = toktype - - if vocab_size > len(tokens): - pad_count = vocab_size - len(tokens) - logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") - for i in range(1, pad_count + 1): - tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) - scores.append(-1000.0) - toktypes.append(SentencePieceTokenTypes.UNUSED) - - # realign tokens (see HF tokenizer code) - tokens = [b'', b'', b'', b''] + tokens[3:-1] - scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1] - toktypes = [ - SentencePieceTokenTypes.CONTROL, - SentencePieceTokenTypes.CONTROL, - SentencePieceTokenTypes.CONTROL, - SentencePieceTokenTypes.UNKNOWN, - ] + toktypes[3:-1] - - self.gguf_writer.add_tokenizer_model("t5") - self.gguf_writer.add_tokenizer_pre("default") - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) - self.gguf_writer.add_add_space_prefix(add_prefix) - self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) - self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces) - if precompiled_charsmap: - self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap) - - special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) - special_vocab.add_to_gguf(self.gguf_writer) - - self.gguf_writer.add_add_bos_token(True) - self.gguf_writer.add_add_eos_token(True) + self._xlmroberta_set_vocab() def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # if name starts with "roberta.", remove the prefix @@ -3373,8 +4122,8 @@ class XLMRobertaModel(BertModel): return super().modify_tensors(data_torch, name, bid) -@Model.register("GemmaForCausalLM") -class GemmaModel(Model): +@ModelBase.register("GemmaForCausalLM") +class GemmaModel(TextModel): model_arch = gguf.MODEL_ARCH.GEMMA def set_vocab(self): @@ -3424,8 +4173,8 @@ class GemmaModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("Gemma2ForCausalLM") -class Gemma2Model(Model): +@ModelBase.register("Gemma2ForCausalLM") +class Gemma2Model(TextModel): model_arch = gguf.MODEL_ARCH.GEMMA2 def set_vocab(self): @@ -3471,27 +4220,9 @@ class Gemma2Model(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration") -class Gemma3Model(Model): +@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration") +class Gemma3Model(TextModel): model_arch = gguf.MODEL_ARCH.GEMMA3 - has_vision: bool = False - - # we need to merge the text_config into the root level of hparams - def __init__(self, *args, **kwargs): - hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0]) - if "text_config" in hparams: - hparams = {**hparams, **hparams["text_config"]} - kwargs["hparams"] = hparams - super().__init__(*args, **kwargs) - if "vision_config" in hparams: - logger.info("Has vision encoder, but it will be ignored") - self.has_vision = True - - def write(self): - super().write() - if self.has_vision: - logger.info("NOTE: this script only convert the language model to GGUF") - logger.info(" for the vision model, please use gemma3_convert_encoder_to_gguf.py") def set_vocab(self): self._set_vocab_sentencepiece() @@ -3529,10 +4260,10 @@ class Gemma3Model(Model): if name.startswith("language_model."): name = name.replace("language_model.", "") + elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ - or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later - # ignore vision tensors - return [] + or name.startswith("multimodal_projector.") or name.startswith("vision_model."): + return [] # skip vision tensors # remove OOV (out-of-vocabulary) rows in token_embd if "embed_tokens.weight" in name: @@ -3548,13 +4279,65 @@ class Gemma3Model(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("Starcoder2ForCausalLM") -class StarCoder2Model(Model): +@ModelBase.register("Gemma3ForConditionalGeneration") +class Gemma3VisionModel(MmprojModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3) + # default values below are taken from HF tranformers code + self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_vision_use_gelu(True) + # calculate proj_scale_factor (used by tinygemma3 test model) + image_seq_length = self.preprocessor_config.get("image_seq_length", 256) + n_per_side = int(image_seq_length ** 0.5) + image_size = self.hparams["image_size"] + patch_size = self.hparams["patch_size"] + proj_scale_factor = (image_size // patch_size) // n_per_side + if proj_scale_factor > 0 and proj_scale_factor != 4: + # we only need to write this if it's not the default value + # in this case, we are converting a test model + self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + # related to https://github.com/ggml-org/llama.cpp/issues/13025 + if "input_projection" in name: + return gguf.GGMLQuantizationType.F16 + if ".embeddings." in name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if "vision_model.head." in name: + return [] # skip redundant tensors for tinygemma3 + + if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ + or name.startswith("multimodal_projector.") or name.startswith("vision_model."): + # process vision tensors + name = name.replace("_weight", ".weight") + + # correct norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector + # the other norm values are part of SigLIP model, and they are already correct + # ref code: Gemma3RMSNorm + if "soft_emb_norm.weight" in name: + logger.info(f"Correcting norm value for '{name}'") + data_torch = data_torch + 1 + + return [(self.map_tensor_name(name), data_torch)] + + return [] # skip other tensors + + +@ModelBase.register("Starcoder2ForCausalLM") +class StarCoder2Model(TextModel): model_arch = gguf.MODEL_ARCH.STARCODER2 -@Model.register("Rwkv6ForCausalLM") -class Rwkv6Model(Model): +@ModelBase.register("Rwkv6ForCausalLM") +class Rwkv6Model(TextModel): model_arch = gguf.MODEL_ARCH.RWKV6 def set_vocab(self): @@ -3626,7 +4409,7 @@ class Rwkv6Model(Model): yield (new_name, data_torch) -@Model.register("RWKV6Qwen2ForCausalLM") +@ModelBase.register("RWKV6Qwen2ForCausalLM") class RWKV6Qwen2Model(Rwkv6Model): model_arch = gguf.MODEL_ARCH.RWKV6QWEN2 @@ -3680,8 +4463,8 @@ class RWKV6Qwen2Model(Rwkv6Model): yield (new_name, data) -@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM") -class Rwkv7Model(Model): +@ModelBase.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM") +class Rwkv7Model(TextModel): model_arch = gguf.MODEL_ARCH.RWKV7 def set_vocab(self): @@ -3799,7 +4582,7 @@ class Rwkv7Model(Model): yield (new_name, data_torch) -@Model.register("RwkvHybridForCausalLM") +@ModelBase.register("RwkvHybridForCausalLM") class ARwkv7Model(Rwkv7Model): model_arch = gguf.MODEL_ARCH.ARWKV7 @@ -3842,8 +4625,8 @@ class ARwkv7Model(Rwkv7Model): self.gguf_writer.add_head_count(0) -@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") -class MambaModel(Model): +@ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") +class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA def set_vocab(self): @@ -3920,8 +4703,8 @@ class MambaModel(Model): return [(new_name, data_torch)] -@Model.register("CohereForCausalLM") -class CommandR2Model(Model): +@ModelBase.register("CohereForCausalLM") +class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R def __init__(self, *args, **kwargs): @@ -3938,8 +4721,8 @@ class CommandR2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) -@Model.register("Cohere2ForCausalLM") -class Cohere2Model(Model): +@ModelBase.register("Cohere2ForCausalLM") +class Cohere2Model(TextModel): model_arch = gguf.MODEL_ARCH.COHERE2 def set_gguf_parameters(self): @@ -3956,9 +4739,9 @@ class Cohere2Model(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) -@Model.register("OlmoForCausalLM") -@Model.register("OLMoForCausalLM") -class OlmoModel(Model): +@ModelBase.register("OlmoForCausalLM") +@ModelBase.register("OLMoForCausalLM") +class OlmoModel(TextModel): model_arch = gguf.MODEL_ARCH.OLMO def set_gguf_parameters(self): @@ -3984,13 +4767,13 @@ class OlmoModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("Olmo2ForCausalLM") -class Olmo2Model(Model): +@ModelBase.register("Olmo2ForCausalLM") +class Olmo2Model(TextModel): model_arch = gguf.MODEL_ARCH.OLMO2 -@Model.register("OlmoeForCausalLM") -class OlmoeModel(Model): +@ModelBase.register("OlmoeForCausalLM") +class OlmoeModel(TextModel): model_arch = gguf.MODEL_ARCH.OLMOE def set_gguf_parameters(self): @@ -4049,29 +4832,10 @@ class OlmoeModel(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("JinaBertModel", "JinaBertForMaskedLM") +@ModelBase.register("JinaBertModel", "JinaBertForMaskedLM") class JinaBertV2Model(BertModel): model_arch = gguf.MODEL_ARCH.JINA_BERT_V2 - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.intermediate_size = self.hparams["intermediate_size"] - - def get_tensors(self): - for name, data in super().get_tensors(): - if 'gated_layer' in name: - d1 = data[:self.intermediate_size, :] - name1 = name.replace('gated_layers', 'gated_layers_w') - name1 = name1.replace('up_gated_layer', 'gated_layers_v') - d2 = data[self.intermediate_size:, :] - name2 = name.replace('gated_layers', 'gated_layers_v') - name2 = name2.replace('up_gated_layer', 'gated_layers_w') - yield name1, d1 - yield name2, d2 - continue - - yield name, data - def set_vocab(self): tokenizer_class = 'BertTokenizer' with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: @@ -4087,17 +4851,9 @@ class JinaBertV2Model(BertModel): self.gguf_writer.add_add_bos_token(True) self.gguf_writer.add_add_eos_token(True) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # if name starts with "bert.", remove the prefix - # e.g. https://huggingface.co/jinaai/jina-reranker-v1-tiny-en - if name.startswith("bert."): - name = name[5:] - return super().modify_tensors(data_torch, name, bid) - - -@Model.register("OpenELMForCausalLM") -class OpenELMModel(Model): +@ModelBase.register("OpenELMForCausalLM") +class OpenELMModel(TextModel): model_arch = gguf.MODEL_ARCH.OPENELM @staticmethod @@ -4171,8 +4927,8 @@ class OpenELMModel(Model): yield (self.map_tensor_name(name), data_torch) -@Model.register("ArcticForCausalLM") -class ArcticModel(Model): +@ModelBase.register("ArcticForCausalLM") +class ArcticModel(TextModel): model_arch = gguf.MODEL_ARCH.ARCTIC def set_vocab(self): @@ -4322,8 +5078,8 @@ class ArcticModel(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("DeepseekForCausalLM") -class DeepseekModel(Model): +@ModelBase.register("DeepseekForCausalLM") +class DeepseekModel(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK def set_vocab(self): @@ -4335,9 +5091,7 @@ class DeepseekModel(Model): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - if "head_dim" in hparams: - rope_dim = hparams["head_dim"] - else: + if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) @@ -4413,15 +5167,19 @@ class DeepseekModel(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("DeepseekV2ForCausalLM") -@Model.register("DeepseekV3ForCausalLM") -class DeepseekV2Model(Model): +@ModelBase.register("DeepseekV2ForCausalLM") +@ModelBase.register("DeepseekV3ForCausalLM") +class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 def set_vocab(self): self._set_vocab_gpt2() def set_gguf_parameters(self): + + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + super().set_gguf_parameters() hparams = self.hparams @@ -4430,8 +5188,13 @@ class DeepseekV2Model(Model): if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) - self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(hparams["v_head_dim"]) + + # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) @@ -4447,12 +5210,12 @@ class DeepseekV2Model(Model): self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "yarn": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) - self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"]) _experts: list[dict[str, Tensor]] | None = None @@ -4500,6 +5263,26 @@ class DeepseekV2Model(Model): else: return [] + # note: MLA with the absorption optimization, needs these two split and k_b_proj transposed + if name.endswith("kv_b_proj.weight"): + name_kb = name.replace("kv_b_proj", "k_b_proj") + name_vb = name.replace("kv_b_proj", "v_b_proj") + + n_head_kv = self.hparams["num_key_value_heads"] + v_head_dim = self.hparams["v_head_dim"] + qk_nope_head_dim = self.hparams["qk_nope_head_dim"] + + assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) + + kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) + k_b = k_b.transpose(1, 2) + + return [ + (self.map_tensor_name(name_kb), k_b), + (self.map_tensor_name(name_vb), v_b) + ] + return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): @@ -4512,8 +5295,36 @@ class DeepseekV2Model(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("PLMForCausalLM") -class PLMModel(Model): +@ModelBase.register("Dots1ForCausalLM") +class Dots1Model(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.DOTS1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["num_experts"] = self.hparams["n_routed_experts"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"]) + self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"]) + self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"]) + + if self.hparams["scoring_func"] == "noaux_tc": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + else: + raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}") + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): + if name.endswith("e_score_correction_bias"): + name = name.replace("e_score_correction_bias", "e_score_correction.bias") + if "shared_experts" in name: + return [(self.map_tensor_name(name), data_torch)] + return super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("PLMForCausalLM") +class PLMModel(TextModel): model_arch = gguf.MODEL_ARCH.PLM def set_vocab(self): @@ -4535,11 +5346,11 @@ class PLMModel(Model): super().prepare_tensors() -@Model.register("T5WithLMHeadModel") -@Model.register("T5ForConditionalGeneration") -@Model.register("MT5ForConditionalGeneration") -@Model.register("UMT5ForConditionalGeneration") -class T5Model(Model): +@ModelBase.register("T5WithLMHeadModel") +@ModelBase.register("T5ForConditionalGeneration") +@ModelBase.register("MT5ForConditionalGeneration") +@ModelBase.register("UMT5ForConditionalGeneration") +class T5Model(TextModel): model_arch = gguf.MODEL_ARCH.T5 def __init__(self, *args, **kwargs): @@ -4678,8 +5489,8 @@ class T5Model(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("T5EncoderModel") -class T5EncoderModel(Model): +@ModelBase.register("T5EncoderModel") +class T5EncoderModel(TextModel): model_arch = gguf.MODEL_ARCH.T5ENCODER def __init__(self, *args, **kwargs): @@ -4817,8 +5628,8 @@ class T5EncoderModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("JAISLMHeadModel") -class JaisModel(Model): +@ModelBase.register("JAISLMHeadModel") +class JaisModel(TextModel): model_arch = gguf.MODEL_ARCH.JAIS def __init__(self, *args, **kwargs): @@ -4900,24 +5711,39 @@ class JaisModel(Model): self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias) -@Model.register("Glm4ForCausalLM") -class Glm4Model(Model): +@ModelBase.register("Glm4ForCausalLM") +class Glm4Model(TextModel): model_arch = gguf.MODEL_ARCH.GLM4 def set_vocab(self): - self._set_vocab_gpt2() + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): super().set_gguf_parameters() - if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]: - if self.hparams["rope_scaling"].get("type") == "yarn": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) - self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"]) + rope_dim = self.hparams["head_dim"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) -@Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") -class ChatGLMModel(Model): +@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") +class ChatGLMModel(TextModel): model_arch = gguf.MODEL_ARCH.CHATGLM def set_vocab_chatglm3(self): @@ -5071,8 +5897,8 @@ class ChatGLMModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("NemotronForCausalLM") -class NemotronModel(Model): +@ModelBase.register("NemotronForCausalLM") +class NemotronModel(TextModel): model_arch = gguf.MODEL_ARCH.NEMOTRON def set_vocab(self): @@ -5112,8 +5938,8 @@ class NemotronModel(Model): return [(self.map_tensor_name(name), data_torch)] -@Model.register("ExaoneForCausalLM") -class ExaoneModel(Model): +@ModelBase.register("ExaoneForCausalLM") +class ExaoneModel(TextModel): model_arch = gguf.MODEL_ARCH.EXAONE def set_gguf_parameters(self): @@ -5146,16 +5972,17 @@ class ExaoneModel(Model): rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True) rotary_factor = rotary_factor if rotary_factor is not None else 1.0 self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) - if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]: - if hparams["rope_scaling"].get("type") == "linear": - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): if rope_scaling.get("rope_type", '').lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) - dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + if (dim := self.hparams.get("head_dim")) is None: + dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) factor = rope_scaling.get("factor", 8.0) @@ -5181,7 +6008,7 @@ class ExaoneModel(Model): yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) -@Model.register("GraniteForCausalLM") +@ModelBase.register("GraniteForCausalLM") class GraniteModel(LlamaModel): """Conversion for IBM's GraniteForCausalLM""" model_arch = gguf.MODEL_ARCH.GRANITE @@ -5215,11 +6042,20 @@ class GraniteModel(LlamaModel): logger.info("gguf: (granite) logits_scale = %s", logits_scale) -@Model.register("GraniteMoeForCausalLM") +@ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM") class GraniteMoeModel(GraniteModel): """Conversion for IBM's GraniteMoeForCausalLM""" model_arch = gguf.MODEL_ARCH.GRANITE_MOE + def set_gguf_parameters(self): + """GraniteMoeShared uses GraniteMoe parameters plus the following: + - shared_intermediate_size + """ + super().set_gguf_parameters() + if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): + self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) + logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: """In modeling_granitemoe, the JetMoe implementation of parallel experts is used. This essentially merges w1 and w3 into a single tensor with 2x @@ -5230,17 +6066,26 @@ class GraniteMoeModel(GraniteModel): if name.endswith("block_sparse_moe.input_linear.weight"): ffn_dim = self.hparams["intermediate_size"] assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" - gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :] + gate, up = data_torch.split(ffn_dim, dim=-2) return [ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up), ] + if name.endswith("shared_mlp.input_linear.weight"): + ffn_dim = self.hparams["shared_intermediate_size"] + assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" + gate, up = data_torch.split(ffn_dim, dim=-2) + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), + (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up), + ] + return super().modify_tensors(data_torch, name, bid) -@Model.register("BailingMoeForCausalLM") -class BailingMoeModel(Model): +@ModelBase.register("BailingMoeForCausalLM") +class BailingMoeModel(TextModel): model_arch = gguf.MODEL_ARCH.BAILINGMOE def set_vocab(self): @@ -5249,10 +6094,17 @@ class BailingMoeModel(Model): def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams - rope_dim = hparams.get("head_dim") or hparams["hidden_size"] // hparams["num_attention_heads"] + if (rope_dim := hparams.get("head_dim")) is None: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(rope_dim) - self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) @@ -5275,7 +6127,8 @@ class BailingMoeModel(Model): n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") n_embd = self.hparams["hidden_size"] - head_dim = self.hparams.get("head_dim") or n_embd // n_head + if (head_dim := self.hparams.get("head_dim")) is None: + head_dim = n_embd // n_head output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) @@ -5338,9 +6191,9 @@ class BailingMoeModel(Model): raise ValueError(f"Unprocessed experts: {experts}") -@Model.register("ChameleonForConditionalGeneration") -@Model.register("ChameleonForCausalLM") # obsolete -class ChameleonModel(Model): +@ModelBase.register("ChameleonForConditionalGeneration") +@ModelBase.register("ChameleonForCausalLM") # obsolete +class ChameleonModel(TextModel): model_arch = gguf.MODEL_ARCH.CHAMELEON def set_gguf_parameters(self): @@ -5380,6 +6233,65 @@ class ChameleonModel(Model): return data_torch +@ModelBase.register("UltravoxModel") +class UltravoxModel(TextModel): + model_arch = gguf.MODEL_ARCH.LLAMA # dummy + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument") + + +@ModelBase.register("Qwen2AudioForConditionalGeneration") +class WhisperEncoderModel(MmprojModel): + has_vision_encoder = False # no vision encoder + has_audio_encoder = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hparams["hidden_size"] = self.hparams["d_model"] + self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"] + self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"] + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A) + self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"]) + self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, new_name, n_dims # unused + if ".conv" in name and ".weight" in name: + return gguf.GGMLQuantizationType.F16 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.startswith("language_model."): + # skip language model tensors + return [] + + # prevent clash naming with vision tensors + if name.startswith("multi_modal_projector"): + name = "audio." + name + + if "conv1.bias" in name or "conv2.bias" in name: + # transpose conv1 and conv2 bias + data_torch = data_torch.unsqueeze(-1) + + return [(self.map_tensor_name(name), data_torch)] + + +@ModelBase.register("UltravoxModel") +class UltravoxWhisperEncoderModel(WhisperEncoderModel): + has_vision_encoder = False # no vision encoder + has_audio_encoder = True + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) + ###### CONVERSION LOGIC ###### @@ -5525,6 +6437,10 @@ def parse_args() -> argparse.Namespace: "--remote", action="store_true", help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.", ) + parser.add_argument( + "--mmproj", action="store_true", + help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.", + ) args = parser.parse_args() if not args.print_supported_models and args.model is None: @@ -5550,12 +6466,26 @@ def split_str_to_n_bytes(split_str: str) -> int: return n +def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> str: + # TODO @ngxson : this won't work correctly if the model has both audio & vision encoders + # maybe we should fallback to text model's arch in that case, since not many models have both + text_config = hparams.get("text_config", {}) + vision_config = hparams.get("vision_config", {}) + arch = hparams["architectures"][0] + # if "architectures" is found in the sub-config, use that instead + if model_type == ModelType.TEXT and text_config.get("architectures") is not None: + arch = text_config["architectures"][0] + elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: + arch = vision_config["architectures"][0] + return arch + + def main() -> None: args = parse_args() if args.print_supported_models: logger.error("Supported models:") - Model.print_registered_models() + ModelBase.print_registered_models() sys.exit(0) if args.verbose: @@ -5602,13 +6532,18 @@ def main() -> None: logger.info(f"Loading model: {dir_model.name}") - hparams = Model.load_hparams(dir_model) + if args.mmproj: + if "mmproj" not in fname_out.name: + fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-") with torch.inference_mode(): output_type = ftype_map[args.outtype] - model_architecture = hparams["architectures"][0] + model_type = ModelType.MMPROJ if args.mmproj else ModelType.TEXT + hparams = ModelBase.load_hparams(dir_model) + model_architecture = get_model_architecture(hparams, model_type) + logger.info(f"Model architecture: {model_architecture}") try: - model_class = Model.from_model_architecture(model_architecture) + model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type) except NotImplementedError: logger.error(f"Model {model_architecture} is not supported") sys.exit(1) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 160c9fe0e..2f733f097 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -1,28 +1,6 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# This script downloads the tokenizer models of the specified models from Huggingface and -# generates the get_vocab_base_pre() function for convert_hf_to_gguf.py -# -# This is necessary in order to analyze the type of pre-tokenizer used by the model and -# provide the necessary information to llama.cpp via the GGUF header in order to implement -# the same pre-tokenizer. -# -# ref: https://github.com/ggml-org/llama.cpp/pull/6920 -# -# Instructions: -# -# - Add a new model to the "models" list -# - Run the script with your huggingface token: -# -# python3 convert_hf_to_gguf_update.py -# -# - The convert_hf_to_gguf.py script will have had its get_vocab_base_pre() function updated -# - Update llama.cpp with the new pre-tokenizer if necessary -# -# TODO: generate tokenizer tests for llama.cpp -# - import logging import os import pathlib @@ -32,6 +10,7 @@ import requests import sys import json import shutil +import argparse from hashlib import sha256 from enum import IntEnum, auto @@ -41,6 +20,11 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("convert_hf_to_gguf_update") sess = requests.Session() +convert_py_pth = pathlib.Path("convert_hf_to_gguf.py") +convert_py = convert_py_pth.read_text(encoding="utf-8") +hf_token_pth = pathlib.Path.home() / ".cache" / "huggingface" / "token" +hf_token = hf_token_pth.read_text(encoding="utf-8").strip() if hf_token_pth.exists() else None + class TOKENIZER_TYPE(IntEnum): SPM = auto() @@ -49,20 +33,49 @@ class TOKENIZER_TYPE(IntEnum): UGM = auto() +DOC_STRING = """ +This script downloads the tokenizer models of the specified models from Huggingface and +generates the get_vocab_base_pre() function for convert_hf_to_gguf.py + +/!\\ It is intended to be used by contributors and is not meant to be run by end users + +This is necessary in order to analyze the type of pre-tokenizer used by the model and +provide the necessary information to llama.cpp via the GGUF header in order to implement +the same pre-tokenizer. + +ref: https://github.com/ggml-org/llama.cpp/pull/6920 + +Instructions: + +- Add a new model to the "models" list +- Run the script with your huggingface token + By default, token will be read from ~/.cache/huggingface/token +- The convert_hf_to_gguf.py script will have had its get_vocab_base_pre() function updated +- Update llama.cpp with the new pre-tokenizer if necessary +""" +# TODO: generate tokenizer tests for llama.cpp + +parser = argparse.ArgumentParser(description=DOC_STRING, formatter_class=argparse.RawTextHelpFormatter) +parser.add_argument( + "--full", action="store_true", + help="download full list of models - make sure you have access to all of them", +) +parser.add_argument( + "hf_token", + help="optional HF token", + nargs="?", +) +args = parser.parse_args() +hf_token = args.hf_token if args.hf_token is not None else hf_token + +if hf_token is None: + logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token") + sys.exit(1) + # TODO: this string has to exercise as much pre-tokenizer functionality as possible # will be updated with time - contributions welcome CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' -if len(sys.argv) == 2: - token = sys.argv[1] - if not token.startswith("hf_"): - logger.info("Huggingface token seems invalid") - logger.info("Usage: python convert_hf_to_gguf_update.py ") - sys.exit(1) -else: - logger.info("Usage: python convert_hf_to_gguf_update.py ") - sys.exit(1) - # TODO: add models here, base models preferred models = [ {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, @@ -103,7 +116,6 @@ models = [ {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, - {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", }, {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, @@ -114,7 +126,17 @@ models = [ {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, - {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", }, + {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", }, + {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, +] + +# some models are known to be broken upstream, so we will skip them as exceptions +pre_computed_hashes = [ + # chatglm-bpe has 2 hashes, why? + {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, + {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, + {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, ] @@ -167,9 +189,29 @@ def download_model(model): if os.path.isfile(save_path): logger.info(f"{name}: File {save_path} already exists - skipping") continue - download_file_with_auth(f"{repo}/resolve/main/{file}", token, save_path) + download_file_with_auth(f"{repo}/resolve/main/{file}", hf_token, save_path) +# get list of existing models and chkhsh from the convert_hf_to_gguf.py file +# returns mapping res --> chkhsh +def get_existing_models(convert_py): + pattern = r'if chkhsh == "([a-f0-9]{64})":\s*\n\s*.*\s*res = "([^"]+)"' + matches = re.findall(pattern, convert_py) + output = {} + for chkhsh, res in matches: + output[res] = chkhsh + return output + + +existing_models = {} +all_models = models.copy() +if not args.full: + # Filter out models that already exist in convert_hf_to_gguf.py + existing_models = get_existing_models(convert_py) + all_models = models.copy() + models = [model for model in all_models if model["name"] not in existing_models] + +logging.info(f"Downloading {len(models)} models...") for model in models: try: download_model(model) @@ -180,9 +222,10 @@ for model in models: # generate the source code for the convert_hf_to_gguf.py:get_vocab_base_pre() function: src_ifs = "" -for model in models: +for model in [*all_models, *pre_computed_hashes]: name = model["name"] tokt = model["tokt"] + chkhsh = model.get("chkhsh") if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM: continue @@ -193,35 +236,44 @@ for model in models: continue # create the tokenizer - try: - if name == "t5": - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) - else: - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") - except OSError as e: - logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") - continue # Skip to the next model if the tokenizer can't be loaded + if chkhsh is not None: + # if the model has a pre-computed hash, use it + logger.info(f"Using pre-computed hash for model {name}: {chkhsh}") + elif name in existing_models: + # if the model already exists in convert_hf_to_gguf.py, skip compute hash + chkhsh = existing_models[name] + else: + # otherwise, compute the hash of the tokenizer + try: + logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...") + if name == "t5": + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") + except OSError as e: + logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") + continue # Skip to the next model if the tokenizer can't be loaded - chktok = tokenizer.encode(CHK_TXT) - chkhsh = sha256(str(chktok).encode()).hexdigest() + chktok = tokenizer.encode(CHK_TXT) + chkhsh = sha256(str(chktok).encode()).hexdigest() - logger.info(f"model: {name}") - logger.info(f"tokt: {tokt}") - logger.info(f"repo: {model['repo']}") - logger.info(f"chktok: {chktok}") - logger.info(f"chkhsh: {chkhsh}") + logger.info(f"model: {name}") + logger.info(f"tokt: {tokt}") + logger.info(f"repo: {model['repo']}") + logger.info(f"chktok: {chktok}") + logger.info(f"chkhsh: {chkhsh}") - # print the "pre_tokenizer" content from the tokenizer.json - with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: - cfg = json.load(f) - normalizer = cfg["normalizer"] - logger.info("normalizer: " + json.dumps(normalizer, indent=4)) - pre_tokenizer = cfg["pre_tokenizer"] - logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) - if "ignore_merges" in cfg["model"]: - logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) + # print the "pre_tokenizer" content from the tokenizer.json + with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: + cfg = json.load(f) + normalizer = cfg["normalizer"] + logger.info("normalizer: " + json.dumps(normalizer, indent=4)) + pre_tokenizer = cfg["pre_tokenizer"] + logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) + if "ignore_merges" in cfg["model"]: + logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) - logger.info("") + logger.info("") src_ifs += f" if chkhsh == \"{chkhsh}\":\n" src_ifs += f" # ref: {model['repo']}\n" @@ -269,8 +321,6 @@ src_func = f""" return res """ -convert_py_pth = pathlib.Path("convert_hf_to_gguf.py") -convert_py = convert_py_pth.read_text(encoding="utf-8") convert_py = re.sub( r"(# Marker: Start get_vocab_base_pre)(.+?)( +# Marker: End get_vocab_base_pre)", lambda m: m.group(1) + src_func + m.group(3), @@ -286,7 +336,7 @@ logger.info("+++ convert_hf_to_gguf.py was updated") tests = [ "ied 4 ½ months", - "Führer", + "Äpfel", "", " ", " ", @@ -365,6 +415,10 @@ for model in models: logger.error(f"Failed to load tokenizer for model {name}. Error: {e}") continue # Skip this model and continue with the next one in the loop + if not os.path.exists(f"models/ggml-vocab-{name}.gguf"): + logger.info(f"Skip vocab files for model {name}, no GGUF file found") + continue + with open(f"models/ggml-vocab-{name}.gguf.inp", "w", encoding="utf-8") as f: for text in tests: f.write(f"{text}") diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py index bdc991533..00a6733cb 100755 --- a/convert_lora_to_gguf.py +++ b/convert_lora_to_gguf.py @@ -24,7 +24,7 @@ if 'NO_LOCAL_GGUF' not in os.environ: import gguf # reuse model definitions from convert_hf_to_gguf.py -from convert_hf_to_gguf import LazyTorchTensor, Model +from convert_hf_to_gguf import LazyTorchTensor, ModelBase logger = logging.getLogger("lora-to-gguf") @@ -340,11 +340,11 @@ if __name__ == '__main__': sys.exit(1) else: logger.info(f"Loading base model: {dir_base_model.name}") - hparams = Model.load_hparams(dir_base_model) + hparams = ModelBase.load_hparams(dir_base_model) with torch.inference_mode(): try: - model_class = Model.from_model_architecture(hparams["architectures"][0]) + model_class = ModelBase.from_model_architecture(hparams["architectures"][0]) except NotImplementedError: logger.error(f"Model {hparams['architectures'][0]} is not supported") sys.exit(1) diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md old mode 100644 new mode 100755 index 23f10175a..2b001f09a --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -8,6 +8,7 @@ - [DataType Supports](#datatype-supports) - [Docker](#docker) - [Linux](#linux) + - [Environment variable setup](#environment-variable-setup) - [TODO](#todo) @@ -56,60 +57,82 @@ The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the abi ## Model Supports -| Model Name | FP16 | Q8_0 | Q4_0 | +| Model Name | FP16 | Q4_0 | Q8_0 | |:----------------------------|:-----:|:----:|:----:| -| AquilaChat2-7B | √ | √ | √ | -| Baichuan-7b | √ | √ | √ | -| Baichuan2-7B-Chat | √ | √ | √ | -| bitnet_b1_58-large | √ | √ | √ | -| bloom-560m | √ | x | √ | -| bloomz-alpaca-560m | √ | x | √ | -| c4ai-command-r-35B-v01 | x | x | x | -| chatglm3-6B | x | x | x | -| chinese-alpaca-2-1.3b | √ | √ | √ | -| CodeShell-7B | √ | √ | √ | -| deepseek-ai_deepseek-coder-1.3B-base | x | x | x | -| deepseek-ai_DeepSeek-V2-Lite | x | x | x | -| deepseek-coder-6.7B-instruct | x | x | x | -| DeepSeek-V2-Lite-64x1.5B | x | x | x | -| falcon-7b-instruct | √ | √ | √ | -| flan-t5-large | √ | √ | √ | -| gemma-2-9b-it | √ | √ | √ | -| glm-4-9B | x | x | x | -| gpt2 | √ | √ | √ | -| Gpt2-163M | √ | √ | √ | -| granite-3B-code-instruct | √ | √ | √ | +| Llama-2 | √ | √ | √ | +| Llama-3 | √ | √ | √ | +| Mistral-7B | √ | √ | √ | +| Mistral MOE | √ | √ | √ | +| DBRX | - | - | - | +| Falcon | √ | √ | √ | +| Chinese LLaMA/Alpaca | √ | √ | √ | +| Vigogne(French) | √ | √ | √ | +| BERT | x | x | x | +| Koala | √ | √ | √ | +| Baichuan | √ | √ | √ | +| Aquila 1 & 2 | √ | √ | √ | +| Starcoder models | √ | √ | √ | +| Refact | √ | √ | √ | +| MPT | √ | √ | √ | +| Bloom | √ | √ | √ | +| Yi models | √ | √ | √ | +| stablelm models | √ | √ | √ | +| DeepSeek models | x | x | x | +| Qwen models | √ | √ | √ | +| PLaMo-13B | √ | √ | √ | +| Phi models | √ | √ | √ | +| PhiMoE | √ | √ | √ | +| GPT-2 | √ | √ | √ | +| Orion | √ | √ | √ | +| InternlLM2 | √ | √ | √ | +| CodeShell | √ | √ | √ | +| Gemma | √ | √ | √ | +| Mamba | √ | √ | √ | +| Xverse | √ | √ | √ | +| command-r models | √ | √ | √ | +| Grok-1 | - | - | - | +| SEA-LION | √ | √ | √ | | GritLM-7B | √ | √ | √ | -| internlm2_5-7b-chat | √ | √ | √ | -| koala-7B-HF | √ | √ | √ | -| Llama-2-7b-chat-hf | √ | √ | √ | -| Llama-3-Smaug-8B | √ | √ | √ | -| Llama2-Chinese-7b-Chat | √ | √ | √ | -| Llama3-8B | √ | √ | √ | -| Llama3-8b-chinese | √ | √ | √ | -| mamba-130m-hf | √ | √ | √ | -| Mistral-7B-Instruct-v0.2 | √ | √ | √ | -| Mixtral-8x7B-Instruct-v0.1 | x | √ | √ | -| mpt-7B | √ | √ | √ | -| OLMo-1B-hf | √ | √ | √ | -| OpenELM-3B-Instruct | √ | √ | √ | -| Orion-14b-base | √ | √ | √ | -| phi1 | x | x | x | -| phi2 | x | x | x | -| Phi-3-mini-4k-instruct | √ | √ | √ | -| plamo-13b | √ | √ | √ | -| pythia-70M | x | x | x | -| Qwen-7B | √ | √ | √ | -| Qwen2-1.5B-Instruct | √ | x | √ | -| Refact-1_6B-fim | √ | √ | √ | -| SmolLM-135M | √ | √ | √ | -| stablelm-zephyr | x | x | x | -| stablelm-2-zephyr-1_6b | x | x | x | -| starcoderbase-1b | √ | √ | √ | -| starcoder2-3b | √ | √ | √ | -| vigogne-7b-chat | √ | √ | √ | -| xverse-7b-chat | √ | √ | √ | -| Yi-6b-Chat | √ | √ | √ | +| OLMo | √ | √ | √ | +| OLMo 2 | √ | √ | √ | +| OLMoE | √ | √ | √ | +| Granite models | √ | √ | √ | +| GPT-NeoX | √ | √ | √ | +| Pythia | √ | √ | √ | +| Snowflake-Arctic MoE | - | - | - | +| Smaug | √ | √ | √ | +| Poro 34B | √ | √ | √ | +| Bitnet b1.58 models | √ | x | x | +| Flan-T5 | √ | √ | √ | +| Open Elm models | x | √ | √ | +| chatGLM3-6B + ChatGLM4-9b + GLMEdge-1.5b + GLMEdge-4b | √ | √ | √ | +| GLM-4-0414 | √ | √ | √ | +| SmolLM | √ | √ | √ | +| EXAONE-3.0-7.8B-Instruct | √ | √ | √ | +| FalconMamba Models | √ | √ | √ | +| Jais Models | - | x | x | +| Bielik-11B-v2.3 | √ | √ | √ | +| RWKV-6 | - | √ | √ | +| QRWKV-6 | √ | √ | √ | +| GigaChat-20B-A3B | x | x | x | +| Trillion-7B-preview | √ | √ | √ | +| Ling models | √ | √ | √ | + + +**Multimodal** +| Model Name | FP16 | Q4_0 | Q8_0 | +|:----------------------------|:-----:|:----:|:----:| +| LLaVA 1.5 models, LLaVA 1.6 models | x | x | x | +| BakLLaVA | √ | √ | √ | +| Obsidian | √ | - | - | +| ShareGPT4V | x | - | - | +| MobileVLM 1.7B/3B models | - | - | - | +| Yi-VL | - | - | - | +| Mini CPM | √ | √ | √ | +| Moondream | √ | √ | √ | +| Bunny | √ | - | - | +| GLM-EDGE | √ | √ | √ | +| Qwen2-VL | √ | √ | √ | @@ -258,6 +281,34 @@ cmake --build build --config release ### **GitHub contribution**: Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay. +## Updates +### Basic Flash Attention Support +The basic FA kernel with aclnnops has been added in aclnn_ops.cpp. +Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap. +Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future. + +Authors from Peking University: Bizhao Shi (bshi@pku.edu.cn), Yuxin Yang (yxyang@pku.edu.cn), Ruiyang Ma (ruiyang@stu.pku.edu.cn), and Guojie Luo (gluo@pku.edu.cn). + +We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers from Huawei Technologies Co., Ltd for their help during the code development and pull request. + +## Environment variable setup + +### GGML_CANN_ASYNC_MODE + +Enables asynchronous operator submission. Disabled by default. + +### GGML_CANN_MEM_POOL + +Specifies the memory pool management strategy: + +- vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool. + +- prio: Employs a priority queue-based memory pool management. +- leg: Uses a fixed-size buffer pool. + +### GGML_CANN_DISABLE_BUF_POOL_CLEAN + +Controls automatic cleanup of the memory pool. This option is only effective when using the prio or leg memory pool strategies. ## TODO - Support more models and data types. diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 20aefec2f..249e73451 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -17,25 +17,25 @@ **SYCL** is a high-level parallel programming model designed to improve developers productivity writing code across various hardware accelerators such as CPUs, GPUs, and FPGAs. It is a single-source language designed for heterogeneous computing and based on standard C++17. -**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: +**oneAPI** is an open ecosystem and a standard-based specification, supporting multiple architectures including but not limited to Intel CPUs, GPUs and FPGAs. The key components of the oneAPI ecosystem include: - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. - **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*. -- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over intel iGPUs and dGPUs. +- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over Intel iGPUs and dGPUs. - **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets. ### Llama.cpp + SYCL -The llama.cpp SYCL backend is designed to support **Intel GPU** firstly. Based on the cross-platform feature of SYCL, it also supports other vendor GPUs: Nvidia and AMD. +The llama.cpp SYCL backend is primarily designed for **Intel GPUs**. +SYCL cross-platform capabilities enable support for Nvidia GPUs as well, with limited support for AMD. ## Recommended Release -The SYCL backend would be broken by some PRs due to no online CI. - -The following release is verified with good quality: +The following releases are verified and recommended: |Commit ID|Tag|Release|Verified Platform| Update date| |-|-|-|-|-| +|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1
LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15| |3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19| |fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1
MTL Arc GPU/Windows 11/oneAPI 2024.1|| @@ -106,15 +106,14 @@ SYCL backend supports Intel GPU Family: |-------------------------------|---------|---------------------------------------| | Intel Data Center Max Series | Support | Max 1550, 1100 | | Intel Data Center Flex Series | Support | Flex 170 | -| Intel Arc Series | Support | Arc 770, 730M, Arc A750 | -| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake | -| Intel iGPU | Support | iGPU in 13700k,iGPU in 13400, i5-1250P, i7-1260P, i7-1165G7 | +| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 | +| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake | +| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 | *Notes:* - **Memory** - The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-cli`. - - Please make sure the GPU shared memory from the host is large enough to account for the model's size. For e.g. the *llama-2-7b.Q4_0* requires at least 8.0GB for integrated GPU and 4.0GB for discrete GPU. - **Execution Unit (EU)** @@ -138,9 +137,11 @@ Note: AMD GPU support is highly experimental and is incompatible with F16. Additionally, it only supports GPUs with a sub_group_size (warp size) of 32. ## Docker -The docker build option is currently limited to *intel GPU* targets. + +The docker build option is currently limited to *Intel GPU* targets. ### Build image + ```sh # Using FP16 docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile . @@ -148,9 +149,10 @@ docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f *Notes*: -To build in default FP32 *(Slower than FP16 alternative)*, you can remove the `--build-arg="GGML_SYCL_F16=ON"` argument from the previous command. +To build in default FP32 *(Slower than FP16 alternative)*, set `--build-arg="GGML_SYCL_F16=OFF"` in the previous command. You can also use the `.devops/llama-server-intel.Dockerfile`, which builds the *"server"* alternative. +Check the [documentation for Docker](../docker.md) to see the available images. ### Run container @@ -250,7 +252,7 @@ sycl-ls - **Intel GPU** -When targeting an intel GPU, the user should expect one or more level-zero devices among the available SYCL devices. Please make sure that at least one GPU is present, for instance [`level_zero:gpu`] in the sample output below: +When targeting an intel GPU, the user should expect one or more devices among the available SYCL devices. Please make sure that at least one GPU is present via `sycl-ls`, for instance `[level_zero:gpu]` in the sample output below: ``` [opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.10.0.17_160000] @@ -282,7 +284,7 @@ For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]: #### Intel GPU -``` +```sh ./examples/sycl/build.sh ``` @@ -351,7 +353,7 @@ cmake --build build --config Release -j -v #### Retrieve and prepare model -You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration, or simply download [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) model as example. +You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). ##### Check device @@ -398,11 +400,15 @@ Choose one of following methods to run. ```sh ./examples/sycl/run-llama2.sh 0 +# OR +./examples/sycl/run-llama3.sh 0 ``` - Use multiple devices: ```sh ./examples/sycl/run-llama2.sh +# OR +./examples/sycl/run-llama3.sh ``` 2. Command line @@ -425,13 +431,13 @@ Examples: - Use device 0: ```sh -ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm none -mg 0 +ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0 ``` - Use multiple devices: ```sh -ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 33 -sm layer +ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer ``` *Notes:* @@ -452,7 +458,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512 1. Install GPU driver -Intel GPU drivers instructions guide and download page can be found here: [Get intel GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html). +Intel GPU drivers instructions guide and download page can be found here: [Get Intel GPU Drivers](https://www.intel.com/content/www/us/en/products/docs/discrete-gpus/arc/software/drivers.html). 2. Install Visual Studio @@ -629,7 +635,7 @@ Once it is completed, final results will be in **build/Release/bin** #### Retrieve and prepare model -You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model prepration, or simply download [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) model as example. +You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). ##### Check device @@ -648,7 +654,7 @@ Similar to the native `sycl-ls`, available SYCL devices can be queried as follow build\bin\llama-ls-sycl-device.exe ``` -This command will only display the selected backend that is supported by SYCL. The default backend is level_zero. For example, in a system with 2 *intel GPU* it would look like the following: +This command will only display the selected backend that is supported by SYCL. The default backend is level_zero. For example, in a system with 2 *Intel GPU* it would look like the following: ``` found 2 SYCL devices: | | | |Compute |Max compute|Max work|Max sub| | @@ -658,13 +664,14 @@ found 2 SYCL devices: | 1|[level_zero:gpu:1]| Intel(R) UHD Graphics 770| 1.3| 32| 512| 32| 53651849216| ``` + #### Choose level-zero devices |Chosen Device ID|Setting| |-|-| -|0|`set ONEAPI_DEVICE_SELECTOR="level_zero:1"` or no action| +|0|Default option. You may also want to `set ONEAPI_DEVICE_SELECTOR="level_zero:0"`| |1|`set ONEAPI_DEVICE_SELECTOR="level_zero:1"`| -|0 & 1|`set ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"`| +|0 & 1|`set ONEAPI_DEVICE_SELECTOR="level_zero:0;level_zero:1"` or `set ONEAPI_DEVICE_SELECTOR="level_zero:*"`| #### Execute @@ -673,7 +680,13 @@ Choose one of following methods to run. 1. Script ``` -examples\sycl\win-run-llama2.bat +examples\sycl\win-run-llama-2.bat +``` + +or + +``` +examples\sycl\win-run-llama-3.bat ``` 2. Command line @@ -697,13 +710,13 @@ Examples: - Use device 0: ``` -build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm none -mg 0 +build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0 ``` - Use multiple devices: ``` -build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 33 -s 0 -sm layer +build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer ``` @@ -714,7 +727,9 @@ Note: ```sh detect 1 SYCL GPUs: [0] with top Max compute units:512 ``` + Or + ```sh use 1 SYCL GPUs: [0] with Max compute units:512 ``` @@ -726,14 +741,17 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | Name | Value | Function | |--------------------|---------------------------------------|---------------------------------------------| -| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path.
FP32 path - recommended for better perforemance than FP16 on quantized model| +| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path. | | GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. | | GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. | -| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. | +| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. (1.) | | GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). | +| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. | | CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. | | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. | +1. FP16 is recommended for better prompt processing performance on quantized models. Performance is equivalent in text generation but set `GGML_SYCL_F16=OFF` if you are experiencing issues with FP16 builds. + #### Runtime | Name | Value | Function | @@ -741,6 +759,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG | | GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features based on Intel GPU type, to compare the performance increase | | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | +| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | @@ -750,7 +769,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512 ## Q&A -- Error: `error while loading shared libraries: libsycl.so.7: cannot open shared object file: No such file or directory`. +- Error: `error while loading shared libraries: libsycl.so: cannot open shared object file: No such file or directory`. - Potential cause: Unavailable oneAPI installation or not set ENV variables. - Solution: Install *oneAPI base toolkit* and enable its ENV through: `source /opt/intel/oneapi/setvars.sh`. @@ -779,18 +798,18 @@ use 1 SYCL GPUs: [0] with Max compute units:512 It's same for other projects including llama.cpp SYCL backend. -- Meet issue: `Native API failed. Native API returns: -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -6 (PI_ERROR_OUT_OF_HOST_MEMORY) -999 (UNKNOWN PI error)` or `failed to allocate SYCL0 buffer` +- `Native API failed. Native API returns: 39 (UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY)`, `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 3503030272 Bytes of memory on device`, or `failed to allocate SYCL0 buffer` - Device Memory is not enough. + You are running out of Device Memory. |Reason|Solution| |-|-| - |Default Context is too big. It leads to more memory usage.|Set `-c 8192` or smaller value.| - |Model is big and require more memory than device's.|Choose smaller quantized model, like Q5 -> Q4;
Use more than one devices to load model.| + | The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.| + | The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;
Alternatively, use more than one device to load model.| ### **GitHub contribution**: -Please add the **[SYCL]** prefix/tag in issues/PRs titles to help the SYCL-team check/address them without delay. +Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay. ## TODO -- NA +- Review ZES_ENABLE_SYSMAN: https://github.com/intel/compute-runtime/blob/master/programmers-guide/SYSMAN.md#support-and-limitations diff --git a/docs/build-s390x.md b/docs/build-s390x.md new file mode 100644 index 000000000..f44038c58 --- /dev/null +++ b/docs/build-s390x.md @@ -0,0 +1,157 @@ +> [!IMPORTANT] +> This build documentation is specific only to IBM Z & LinuxONE mainframes (s390x). You can find the build documentation for other architectures: [build.md](build.md). + +# Build llama.cpp locally (for s390x) + +The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](../include/llama.h). + +The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server. + +**To get the code:** + +```bash +git clone https://github.com/ggml-org/llama.cpp +cd llama.cpp +``` + +## CPU Build with BLAS + +Building llama.cpp with BLAS support is highly recommended as it has shown to provide performance improvements. + +```bash +cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_BLAS=ON \ + -DGGML_BLAS_VENDOR=OpenBLAS + +cmake --build build --config Release -j $(nproc) +``` + +**Notes**: +- For faster repeated compilation, install [ccache](https://ccache.dev/) +- By default, VXE/VXE2 is enabled. To disable it (not recommended): + + ```bash + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_BLAS=ON \ + -DGGML_BLAS_VENDOR=OpenBLAS \ + -DGGML_VXE=OFF + + cmake --build build --config Release -j $(nproc) + ``` + +- For debug builds: + + ```bash + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=Debug \ + -DGGML_BLAS=ON \ + -DGGML_BLAS_VENDOR=OpenBLAS + + cmake --build build --config Debug -j $(nproc) + ``` + +- For static builds, add `-DBUILD_SHARED_LIBS=OFF`: + + ```bash + cmake -S . -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_BLAS=ON \ + -DGGML_BLAS_VENDOR=OpenBLAS \ + -DBUILD_SHARED_LIBS=OFF + + cmake --build build --config Release -j $(nproc) + ``` + +## Getting GGUF Models + +All models need to be converted to Big-Endian. You can achieve this in three cases: + +1. **Use pre-converted models verified for use on IBM Z & LinuxONE (easiest)** + + You can find popular models pre-converted and verified at [s390x Ready Models](hf.co/collections/taronaeo/s390x-ready-models-672765393af438d0ccb72a08). + + These models and their respective tokenizers are verified to run correctly on IBM Z & LinuxONE. + +2. **Convert safetensors model to GGUF Big-Endian directly (recommended)** + + ```bash + python3 convert_hf_to_gguf.py \ + --outfile model-name-be.f16.gguf \ + --outtype f16 \ + --bigendian \ + model-directory/ + ``` + + For example, + + ```bash + python3 convert_hf_to_gguf.py \ + --outfile granite-3.3-2b-instruct-be.f16.gguf \ + --outtype f16 \ + --bigendian \ + granite-3.3-2b-instruct/ + ``` + +3. **Convert existing GGUF Little-Endian model to Big-Endian** + + ```bash + python3 gguf-py/gguf/scripts/gguf_convert_endian.py model-name.f16.gguf BIG + ``` + + For example, + ```bash + python3 gguf-py/gguf/scripts/gguf_convert_endian.py granite-3.3-2b-instruct-le.f16.gguf BIG + mv granite-3.3-2b-instruct-le.f16.gguf granite-3.3-2b-instruct-be.f16.gguf + ``` + + **Notes:** + - The GGUF endian conversion script may not support all data types at the moment and may fail for some models/quantizations. When that happens, please try manually converting the safetensors model to GGUF Big-Endian via Step 2. + +## IBM Accelerators + +### 1. SIMD Acceleration + +Only available in IBM z15 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14 or EC13. In such systems, the APIs can still run but will use a scalar implementation. + +### 2. zDNN Accelerator + +*Only available in IBM z16 or later system. No direction at the moment.* + +### 3. Spyre Accelerator + +*No direction at the moment.* + +## Performance Tuning + +### 1. Virtualization Setup + +It is strongly recommended to use only LPAR (Type-1) virtualization to get the most performance. + +Note: Type-2 virtualization is not supported at the moment, while you can get it running, the performance will not be the best. + +### 2. IFL (Core) Count + +It is recommended to allocate a minimum of 8 shared IFLs assigned to the LPAR. Increasing the IFL count past 8 shared IFLs will only improve Prompt Processing performance but not Token Generation. + +Note: IFL count does not equate to vCPU count. + +### 3. SMT vs NOSMT (Simultaneous Multithreading) + +It is strongly recommended to disable SMT via the kernel boot parameters as it negatively affects performance. Please refer to your Linux distribution's guide on disabling SMT via kernel boot parameters. + +### 4. BLAS vs NOBLAS + +IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongly recommended to use BLAS. + +## Getting Help on IBM Z & LinuxONE + +1. **Bugs, Feature Requests** + + Please file an issue in llama.cpp and ensure that the title contains "s390x". + +2. **Other Questions** + + Please reach out directly to [aionz@us.ibm.com](mailto:aionz@us.ibm.com). + diff --git a/docs/build.md b/docs/build.md index 3f1b04399..680b0d839 100644 --- a/docs/build.md +++ b/docs/build.md @@ -1,5 +1,9 @@ # Build llama.cpp locally +The main product of this project is the `llama` library. Its C-style interface can be found in [include/llama.h](include/llama.h). + +The project also includes many example programs and tools using the `llama` library. The examples range from simple, minimal code snippets to sophisticated sub-projects such as an OpenAI-compatible HTTP server. + **To get the Code:** ```bash @@ -63,6 +67,7 @@ cmake --build build --config Release cmake --preset x64-windows-llvm-release cmake --build build-x64-windows-llvm-release ``` +- Curl usage is enabled by default and can be turned off with `-DLLAMA_CURL=OFF`. Otherwise you need to install development libraries for libcurl. ## BLAS Build @@ -259,8 +264,6 @@ You can download it from your Linux distro's package manager or from here: [ROCm cmake -S . -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \ && cmake --build build --config Release -- -j 16 ``` - On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DGGML_HIP_UMA=ON`. - However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs). To enhance flash attention performance on RDNA3+ or CDNA architectures, you can utilize the rocWMMA library by enabling the `-DGGML_HIP_ROCWMMA_FATTN=ON` option. This requires rocWMMA headers to be installed on the build system. @@ -296,6 +299,10 @@ You can download it from your Linux distro's package manager or from here: [ROCm The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used. If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. +### Unified Memory + +On Linux it is possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1`. However, this hurts performance for non-integrated GPUs (but enables working with integrated GPUs). + ## Vulkan **Windows** diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 78c6f7607..7f71e0247 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -9,10 +9,10 @@ Adding a model requires few steps: After following these steps, you can open PR. Also, it is important to check that the examples and main ggml backends (CUDA, METAL, CPU) are working with the new architecture, especially: -- [main](/examples/main/) -- [imatrix](/examples/imatrix/) -- [quantize](/examples/quantize/) -- [server](/examples/server/) +- [main](/tools/main/) +- [imatrix](/tools/imatrix/) +- [quantize](/tools/quantize/) +- [server](/tools/server/) ### 1. Convert the model to GGUF diff --git a/docs/docker.md b/docs/docker.md index 343146dbd..f8f0573c1 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -22,6 +22,9 @@ Additionally, there the following images, similar to the above: - `ghcr.io/ggml-org/llama.cpp:full-musa`: Same as `full` but compiled with MUSA support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:light-musa`: Same as `light` but compiled with MUSA support. (platforms: `linux/amd64`) - `ghcr.io/ggml-org/llama.cpp:server-musa`: Same as `server` but compiled with MUSA support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:full-intel`: Same as `full` but compiled with SYCL support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:light-intel`: Same as `light` but compiled with SYCL support. (platforms: `linux/amd64`) +- `ghcr.io/ggml-org/llama.cpp:server-intel`: Same as `server` but compiled with SYCL support. (platforms: `linux/amd64`) The GPU enabled images are not currently tested by CI beyond being built. They are not built with any variation from the ones in the Dockerfiles defined in [.devops/](../.devops/) and the GitHub Action defined in [.github/workflows/docker.yml](../.github/workflows/docker.yml). If you need different settings (for example, a different CUDA, ROCm or MUSA library, you'll need to build the images locally for now). @@ -104,7 +107,7 @@ You may want to pass in some different `ARGS`, depending on the MUSA environment The defaults are: -- `MUSA_VERSION` set to `rc3.1.1` +- `MUSA_VERSION` set to `rc4.0.1` The resulting images, are essentially the same as the non-MUSA images: diff --git a/docs/function-calling.md b/docs/function-calling.md index c3873c3fa..37eacaf31 100644 --- a/docs/function-calling.md +++ b/docs/function-calling.md @@ -2,7 +2,6 @@ [chat.h](../common/chat.h) (https://github.com/ggml-org/llama.cpp/pull/9639) adds support for [OpenAI-style function calling](https://platform.openai.com/docs/guides/function-calling) and is used in: - `llama-server` when started w/ `--jinja` flag -- `llama-cli` (WIP: https://github.com/ggml-org/llama.cpp/pull/11556) ## Universal support w/ Native & Generic handlers @@ -12,7 +11,7 @@ Function calling is supported for all models (see https://github.com/ggml-org/ll - Llama 3.1 / 3.3 (including builtin tools support - tool names for `wolfram_alpha`, `web_search` / `brave_search`, `code_interpreter`), Llama 3.2 - Functionary v3.1 / v3.2 - Hermes 2/3, Qwen 2.5 - - Qwen 2.5 Coder (WIP: https://github.com/ggml-org/llama.cpp/pull/12034) + - Qwen 2.5 Coder - Mistral Nemo - Firefunction v2 - Command R7B @@ -325,36 +324,65 @@ To get the official template from original HuggingFace repos, you can use [scrip > [!TIP] > If there is no official `tool_use` Jinja template, you may want to set `--chat-template chatml` to use a default that works with many models (YMMV!), or write your own (e.g. we provide a custom [llama-cpp-deepseek-r1.jinja](../models/templates/llama-cpp-deepseek-r1.jinja) for DeepSeek R1 distills) +> [!CAUTION] +> Beware of extreme KV quantizations (e.g. `-ctk q4_0`), they can substantially degrade the model's tool calling performance. + Test in CLI (or with any library / software that can use OpenAI-compatible API backends): ```bash curl http://localhost:8080/v1/chat/completions -d '{ -"model": "gpt-3.5-turbo", -"tools": [ - { - "type":"function", - "function":{ - "name":"python", - "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", - "parameters":{ - "type":"object", - "properties":{ - "code":{ - "type":"string", - "description":"The code to run in the ipython interpreter." + "model": "gpt-3.5-turbo", + "tools": [ + { + "type":"function", + "function":{ + "name":"python", + "description":"Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.", + "parameters":{ + "type":"object", + "properties":{ + "code":{ + "type":"string", + "description":"The code to run in the ipython interpreter." + } + }, + "required":["code"] } - }, - "required":["code"] } - } - } -], -"messages": [ - { - "role": "user", - "content": "Print a hello world message with python." - } -] + } + ], + "messages": [ + { + "role": "user", + "content": "Print a hello world message with python." + } + ] +}' + + +curl http://localhost:8080/v1/chat/completions -d '{ + "model": "gpt-3.5-turbo", + "messages": [ + {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, + {"role": "user", "content": "What is the weather in Istanbul?"} + ], + "tools": [{ + "type":"function", + "function":{ + "name":"get_current_weather", + "description":"Get the current weather in a given location", + "parameters":{ + "type":"object", + "properties":{ + "location":{ + "type":"string", + "description":"The city and country/state, e.g. `San Francisco, CA`, or `Paris, France`" + } + }, + "required":["location"] + } + } + }] }' ``` diff --git a/docs/install.md b/docs/install.md index 4971c1828..7200bf9b7 100644 --- a/docs/install.md +++ b/docs/install.md @@ -1,28 +1,42 @@ # Install pre-built version of llama.cpp -## Homebrew +| Install via | Windows | Mac | Linux | +|-------------|---------|-----|-------| +| Winget | ✅ | | | +| Homebrew | | ✅ | ✅ | +| MacPorts | | ✅ | | +| Nix | | ✅ | ✅ | -On Mac and Linux, the homebrew package manager can be used via +## Winget (Windows) + +```sh +winget install llama.cpp +``` + +The package is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/issues/8188 + +## Homebrew (Mac and Linux) ```sh brew install llama.cpp ``` + The formula is automatically updated with new `llama.cpp` releases. More info: https://github.com/ggml-org/llama.cpp/discussions/7668 -## MacPorts +## MacPorts (Mac) ```sh sudo port install llama.cpp ``` -see also: https://ports.macports.org/port/llama.cpp/details/ -## Nix +See also: https://ports.macports.org/port/llama.cpp/details/ -On Mac and Linux, the Nix package manager can be used via +## Nix (Mac and Linux) ```sh nix profile install nixpkgs#llama-cpp ``` + For flake enabled installs. Or @@ -34,13 +48,3 @@ nix-env --file '' --install --attr llama-cpp For non-flake enabled installs. This expression is automatically updated within the [nixpkgs repo](https://github.com/NixOS/nixpkgs/blob/nixos-24.05/pkgs/by-name/ll/llama-cpp/package.nix#L164). - -## Flox - -On Mac and Linux, Flox can be used to install llama.cpp within a Flox environment via - -```sh -flox install llama-cpp -``` - -Flox follows the nixpkgs build of llama.cpp. diff --git a/docs/multimodal.md b/docs/multimodal.md new file mode 100644 index 000000000..edbd081df --- /dev/null +++ b/docs/multimodal.md @@ -0,0 +1,113 @@ +# Multimodal + +llama.cpp supports multimodal input via `libmtmd`. Currently, there are 2 tools support this feature: +- [llama-mtmd-cli](../tools/mtmd/README.md) +- [llama-server](../tools/server/README.md) via OpenAI-compatible `/chat/completions` API + +Currently, we support **image** and **audio** input. Audio is highly experimental and may have reduced quality. + +To enable it, you can use one of the 2 methods below: + +- Use `-hf` option with a supported model (see a list of pre-quantized model below) + - To load a model using `-hf` while disabling multimodal, use `--no-mmproj` + - To load a model using `-hf` while using a custom mmproj file, use `--mmproj local_file.gguf` +- Use `-m model.gguf` option with `--mmproj file.gguf` to specify text and multimodal projector respectively + +By default, multimodal projector will be offloaded to GPU. To disable this, add `--no-mmproj-offload` + +For example: + +```sh +# simple usage with CLI +llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF + +# simple usage with server +llama-server -hf ggml-org/gemma-3-4b-it-GGUF + +# using local file +llama-server -m gemma-3-4b-it-Q4_K_M.gguf --mmproj mmproj-gemma-3-4b-it-Q4_K_M.gguf + +# no GPU offload +llama-server -hf ggml-org/gemma-3-4b-it-GGUF --no-mmproj-offload +``` + +## Pre-quantized models + +These are ready-to-use models, most of them come with `Q4_K_M` quantization by default. They can be found at the Hugging Face page of the ggml-org: https://huggingface.co/collections/ggml-org/multimodal-ggufs-68244e01ff1f39e5bebeeedc + +Replaces the `(tool_name)` with the name of binary you want to use. For example, `llama-mtmd-cli` or `llama-server` + +NOTE: some models may require large context window, for example: `-c 8192` + +**Vision models**: + +```sh +# Gemma 3 +(tool_name) -hf ggml-org/gemma-3-4b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-12b-it-GGUF +(tool_name) -hf ggml-org/gemma-3-27b-it-GGUF + +# SmolVLM +(tool_name) -hf ggml-org/SmolVLM-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-256M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM-500M-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF +(tool_name) -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF + +# Pixtral 12B +(tool_name) -hf ggml-org/pixtral-12b-GGUF + +# Qwen 2 VL +(tool_name) -hf ggml-org/Qwen2-VL-2B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2-VL-7B-Instruct-GGUF + +# Qwen 2.5 VL +(tool_name) -hf ggml-org/Qwen2.5-VL-3B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-7B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-32B-Instruct-GGUF +(tool_name) -hf ggml-org/Qwen2.5-VL-72B-Instruct-GGUF + +# Mistral Small 3.1 24B (IQ2_M quantization) +(tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF + +# InternVL 2.5 and 3 +(tool_name) -hf ggml-org/InternVL2_5-1B-GGUF +(tool_name) -hf ggml-org/InternVL2_5-4B-GGUF +(tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF + +# Llama 4 Scout +(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF + +# Moondream2 20250414 version +(tool_name) -hf ggml-org/moondream2-20250414-GGUF + +``` + +**Audio models**: + +```sh +# Ultravox 0.5 +(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF +(tool_name) -hf ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF + +# Qwen2-Audio and SeaLLM-Audio +# note: no pre-quantized GGUF this model, as they have very poor result +# ref: https://github.com/ggml-org/llama.cpp/pull/13760 +``` + +**Mixed modalities**: + +```sh +# Qwen2.5 Omni +# Capabilities: audio input, vision input +(tool_name) -hf ggml-org/Qwen2.5-Omni-3B-GGUF +(tool_name) -hf ggml-org/Qwen2.5-Omni-7B-GGUF +``` + +## Finding more models: + +GGUF models on Huggingface with vision capabilities can be found here: https://huggingface.co/models?pipeline_tag=image-text-to-text&sort=trending&search=gguf diff --git a/examples/llava/MobileVLM-README.md b/docs/multimodal/MobileVLM.md similarity index 94% rename from examples/llava/MobileVLM-README.md rename to docs/multimodal/MobileVLM.md index 4f783f3ce..4f5eca619 100644 --- a/examples/llava/MobileVLM-README.md +++ b/docs/multimodal/MobileVLM.md @@ -9,15 +9,15 @@ The implementation is based on llava, and is compatible with llava and mobileVLM Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using **MobileVLM-1.7B** as an example, the different conversion step will be shown. ## Usage -Build with cmake or run `make llama-llava-cli` to build it. -After building, run: `./llama-llava-cli` to see the usage. For example: +Build the `llama-mtmd-cli` binary. + +After building, run: `./llama-mtmd-cli` to see the usage. For example: ```sh -./llama-llava-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \ +./llama-mtmd-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \ --mmproj MobileVLM-1.7B/mmproj-model-f16.gguf \ - --image path/to/an/image.jpg \ - -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? Answer the question using a single word or phrase. ASSISTANT:" + --chat-template deepseek ``` ## Model conversion @@ -33,13 +33,13 @@ git clone https://huggingface.co/openai/clip-vit-large-patch14-336 2. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents: ```sh -python ./examples/llava/llava_surgery.py -m path/to/MobileVLM-1.7B +python ./tools/mtmd/llava_surgery.py -m path/to/MobileVLM-1.7B ``` 3. Use `convert_image_encoder_to_gguf.py` with `--projector-type ldp` (for **V2** please use `--projector-type ldpv2`) to convert the LLaVA image encoder to GGUF: ```sh -python ./examples/llava/convert_image_encoder_to_gguf.py \ +python ./tools/mtmd/convert_image_encoder_to_gguf.py \ -m path/to/clip-vit-large-patch14-336 \ --llava-projector path/to/MobileVLM-1.7B/llava.projector \ --output-dir path/to/MobileVLM-1.7B \ @@ -47,7 +47,7 @@ python ./examples/llava/convert_image_encoder_to_gguf.py \ ``` ```sh -python ./examples/llava/convert_image_encoder_to_gguf.py \ +python ./tools/mtmd/convert_image_encoder_to_gguf.py \ -m path/to/clip-vit-large-patch14-336 \ --llava-projector path/to/MobileVLM-1.7B_V2/llava.projector \ --output-dir path/to/MobileVLM-1.7B_V2 \ @@ -69,10 +69,10 @@ Now both the LLaMA part and the image encoder is in the `MobileVLM-1.7B` directo ## Android compile and run ### compile -refer to `examples/llava/android/build_64.sh` +refer to `tools/mtmd/android/build_64.sh` ```sh -mkdir examples/llava/android/build_64 -cd examples/llava/android/build_64 +mkdir tools/mtmd/android/build_64 +cd tools/mtmd/android/build_64 ../build_64.sh ``` ### run on Android @@ -82,7 +82,7 @@ refer to `android/adb_run.sh`, modify resources' `name` and `path` ### case 1 **input** ```sh -/data/local/tmp/llama-llava-cli \ +/data/local/tmp/llama-mtmd-cli \ -m /data/local/tmp/ggml-model-q4_k.gguf \ --mmproj /data/local/tmp/mmproj-model-f16.gguf \ -t 4 \ @@ -102,7 +102,7 @@ llama_print_timings: total time = 34731.93 ms ### case 2 **input** ```sh -/data/local/tmp/llama-llava-cli \ +/data/local/tmp/llama-mtmd-cli \ -m /data/local/tmp/ggml-model-q4_k.gguf \ --mmproj /data/local/tmp/mmproj-model-f16.gguf \ -t 4 \ @@ -123,10 +123,10 @@ llama_print_timings: total time = 34570.79 ms ## Some result on Android with `Snapdragon 778G` chip ### MobileVLM-1.7B case -#### llava-cli release-b2005 +#### mtmd-cli release-b2005 **input** ```sh -/data/local/tmp/llama-llava-cli \ +/data/local/tmp/llama-mtmd-cli \ -m /data/local/tmp/ggml-model-q4_k.gguf \ --mmproj /data/local/tmp/mmproj-model-f16.gguf \ -t 4 \ @@ -147,7 +147,7 @@ llama_print_timings: prompt eval time = 8119.49 ms / 191 tokens ( 42.51 m llama_print_timings: eval time = 1005.75 ms / 14 runs ( 71.84 ms per token, 13.92 tokens per second) llama_print_timings: total time = 28038.34 ms / 205 tokens ``` -#### llava-cli latest-version +#### mtmd-cli latest-version **input** Just the same as above. @@ -169,7 +169,7 @@ llama_print_timings: eval time = 43894.02 ms / 13 runs ( 3376.46 m llama_print_timings: total time = 865441.76 ms / 204 tokens ``` ### MobileVLM_V2-1.7B case -#### llava-cli release-2005b +#### mtmd-cli release-2005b **input** Just the same as above. @@ -200,7 +200,7 @@ make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 GGML_CUDA_F16=1 -j 32 ### case 1 **input** ```sh -./llama-llava-cli \ +./llama-mtmd-cli \ -m /data/local/tmp/ggml-model-q4_k.gguf \ --mmproj /data/local/tmp/mmproj-model-f16.gguf \ --image /data/local/tmp/demo.jpeg \ @@ -224,7 +224,7 @@ llama_print_timings: total time = 1352.63 ms / 252 tokens ### case 2 **input** ```sh -./llama-llava-cli \ +./llama-mtmd-cli \ -m /data/local/tmp/ggml-model-q4_k.gguf \ --mmproj /data/local/tmp/mmproj-model-f16.gguf \ -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" \ diff --git a/examples/llava/README-gemma3.md b/docs/multimodal/gemma3.md similarity index 54% rename from examples/llava/README-gemma3.md rename to docs/multimodal/gemma3.md index 3c25ee258..110a36f40 100644 --- a/examples/llava/README-gemma3.md +++ b/docs/multimodal/gemma3.md @@ -11,26 +11,27 @@ You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org) ```bash # build cmake -B build -cmake --build build --target llama-gemma3-cli +cmake --build build --target llama-mtmd-cli # alternatively, install from brew (MacOS) brew install llama.cpp # run it -llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF -llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF -llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF +llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF +llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF +llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF # note: 1B model does not support vision ``` ## How to get mmproj.gguf? +Simply to add `--mmproj` in when converting model via `convert_hf_to_gguf.py`: + ```bash cd gemma-3-4b-it -python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py . - -# output file is mmproj.gguf +python ../llama.cpp/convert_hf_to_gguf.py --outfile model.gguf --outtype f16 --mmproj . +# output file: mmproj-model.gguf ``` ## How to run it? @@ -43,8 +44,8 @@ What you need: ```bash # build cmake -B build -cmake --build build --target llama-gemma3-cli +cmake --build build --target llama-mtmd-cli # run it -./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg +./build/bin/llama-mtmd-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg ``` diff --git a/examples/llava/README-glmedge.md b/docs/multimodal/glmedge.md similarity index 67% rename from examples/llava/README-glmedge.md rename to docs/multimodal/glmedge.md index 603d01474..7bae83150 100644 --- a/examples/llava/README-glmedge.md +++ b/docs/multimodal/glmedge.md @@ -3,12 +3,12 @@ Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b). ## Usage -Build with cmake or run `make llama-llava-cli` to build it. +Build the `llama-mtmd-cli` binary. -After building, run: `./llama-llava-cli` to see the usage. For example: +After building, run: `./llama-mtmd-cli` to see the usage. For example: ```sh -./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <|user|>\n prompt <|assistant|>\n" +./llama-mtmd-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf ``` **note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so. @@ -25,13 +25,13 @@ git clone https://huggingface.co/THUDM/glm-edge-v-5b or https://huggingface.co/T 2. Use `glmedge-surgery.py` to split the GLMV-EDGE model to LLM and multimodel projector constituents: ```sh -python ./examples/llava/glmedge-surgery.py -m ../model_path +python ./tools/mtmd/glmedge-surgery.py -m ../model_path ``` 4. Use `glmedge-convert-image-encoder-to-gguf.py` to convert the GLMV-EDGE image encoder to GGUF: ```sh -python ./examples/llava/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path +python ./tools/mtmd/glmedge-convert-image-encoder-to-gguf.py -m ../model_path --llava-projector ../model_path/glm.projector --output-dir ../model_path ``` 5. Use `examples/convert_hf_to_gguf.py` to convert the LLM part of GLMV-EDGE to GGUF: diff --git a/examples/llava/README-granitevision.md b/docs/multimodal/granitevision.md similarity index 92% rename from examples/llava/README-granitevision.md rename to docs/multimodal/granitevision.md index f08a21cc1..3118fe0cd 100644 --- a/examples/llava/README-granitevision.md +++ b/docs/multimodal/granitevision.md @@ -176,15 +176,11 @@ Note that currently you cannot quantize the visual encoder because granite visio ### 5. Running the Model in Llama cpp -Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner. +Build llama cpp normally; you should have a target binary named `llama-mtmd-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner. ```bash -$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \ +$ ./build/bin/llama-mtmd-cli -m $LLM_GGUF_PATH \ --mmproj $VISUAL_GGUF_PATH \ - --image ./media/llama0-banner.png \ -c 16384 \ - -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\\nWhat does the text in this image say?\n<|assistant|>\n" \ --temp 0 ``` - -Sample output: `The text in the image reads "LLAMA C++ Can it run DOOM Llama?"` diff --git a/examples/llava/README.md b/docs/multimodal/llava.md similarity index 69% rename from examples/llava/README.md rename to docs/multimodal/llava.md index 0e3c32032..12354ab60 100644 --- a/examples/llava/README.md +++ b/docs/multimodal/llava.md @@ -11,12 +11,14 @@ For llava-1.6 a variety of prepared gguf models are available as well [7b-34b](h After API is confirmed, more models will be supported / uploaded. ## Usage -Build with cmake or run `make llama-llava-cli` to build it. +Build the `llama-mtmd-cli` binary. -After building, run: `./llama-llava-cli` to see the usage. For example: +After building, run: `./llama-mtmd-cli` to see the usage. For example: ```sh -./llama-llava-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg +./llama-mtmd-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf \ + --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf \ + --chat-template vicuna ``` **note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so. @@ -35,19 +37,19 @@ git clone https://huggingface.co/openai/clip-vit-large-patch14-336 2. Install the required Python packages: ```sh -pip install -r examples/llava/requirements.txt +pip install -r tools/mtmd/requirements.txt ``` 3. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents: ```sh -python ./examples/llava/llava_surgery.py -m ../llava-v1.5-7b +python ./tools/mtmd/llava_surgery.py -m ../llava-v1.5-7b ``` 4. Use `convert_image_encoder_to_gguf.py` to convert the LLaVA image encoder to GGUF: ```sh -python ./examples/llava/convert_image_encoder_to_gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b +python ./tools/mtmd/convert_image_encoder_to_gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b ``` 5. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF: @@ -67,12 +69,12 @@ git clone https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b 2) Install the required Python packages: ```sh -pip install -r examples/llava/requirements.txt +pip install -r tools/mtmd/requirements.txt ``` 3) Use `llava_surgery_v2.py` which also supports llava-1.5 variants pytorch as well as safetensor models: ```console -python examples/llava/llava_surgery_v2.py -C -m ../llava-v1.6-vicuna-7b/ +python tools/mtmd/llava_surgery_v2.py -C -m ../llava-v1.6-vicuna-7b/ ``` - you will find a llava.projector and a llava.clip file in your model directory @@ -86,7 +88,7 @@ curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.jso 5) Create the visual gguf model: ```console -python ./examples/llava/convert_image_encoder_to_gguf.py -m vit --llava-projector vit/llava.projector --output-dir vit --clip-model-is-vision +python ./tools/mtmd/convert_image_encoder_to_gguf.py -m vit --llava-projector vit/llava.projector --output-dir vit --clip-model-is-vision ``` - This is similar to llava-1.5, the difference is that we tell the encoder that we are working with the pure vision model part of CLIP @@ -97,7 +99,7 @@ python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknow 7) And finally we can run the llava cli using the 1.6 model version: ```console -./llama-llava-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf --mmproj vit/mmproj-model-f16.gguf --image some-image.jpg -c 4096 +./llama-mtmd-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf --mmproj vit/mmproj-model-f16.gguf ``` **note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096) @@ -122,17 +124,9 @@ model.language_model.save_pretrained(llm_export_path) Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures. -## llava-cli templating and llava-1.6 prompting +## Chat template -llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."` -For llava-1.5 models which are not vicuna (mistral and Yi) you need to adapt system prompt as well as user prompt, for this purpose llava-cli has a basic templating system: - -**For Mistral and using llava-cli binary:** -Add this: `-p "\nUSER:\nProvide a full description.\nASSISTANT:\n"` -The mistral template for llava-1.6 seems to be no system print and a USER/ASSISTANT role - -**For the 34B this should work:** -Add this: `-e -p <|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nProvide a full description.<|im_end|><|im_start|>assistant\n` +For llava-1.5 and llava-1.6, you need to use `vicuna` chat template. Simply add `--chat-template vicuna` to activate this template. ## How to know if you are running in llava-1.5 or llava-1.6 mode @@ -147,12 +141,3 @@ When running llava-cli you will see a visual information right before the prompt Alternatively just pay notice to how many "tokens" have been used for your prompt, it will also show 1000+ tokens for llava-1.6 - - - - -## TODO - -- [x] Support non-CPU backend for the image encoding part. -- [ ] Support different sampling methods. -- [ ] Support more model variants. diff --git a/examples/llava/README-minicpmo2.6.md b/docs/multimodal/minicpmo2.6.md similarity index 57% rename from examples/llava/README-minicpmo2.6.md rename to docs/multimodal/minicpmo2.6.md index 48c423238..8c6db8efe 100644 --- a/examples/llava/README-minicpmo2.6.md +++ b/docs/multimodal/minicpmo2.6.md @@ -29,8 +29,8 @@ cmake --build build --config Release Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-o-2_6-gguf) by us) ```bash -python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-o-2_6 -python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4 +python ./tools/mtmd/minicpmv-surgery.py -m ../MiniCPM-o-2_6 +python ./tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-o-2_6 --minicpmv-projector ../MiniCPM-o-2_6/minicpmv.projector --output-dir ../MiniCPM-o-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 4 python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model # quantize int4 version @@ -40,9 +40,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model Inference on Linux or Mac ```bash -# run f16 version -./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" -# run quantized int4 version -./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf ``` diff --git a/examples/llava/README-minicpmv2.5.md b/docs/multimodal/minicpmv2.5.md similarity index 54% rename from examples/llava/README-minicpmv2.5.md rename to docs/multimodal/minicpmv2.5.md index 6bfe7abd1..19b439607 100644 --- a/examples/llava/README-minicpmv2.5.md +++ b/docs/multimodal/minicpmv2.5.md @@ -28,8 +28,8 @@ cmake --build build --config Release Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5-gguf) by us) ```bash -python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 -python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2 +python ./tools/mtmd/minicpmv-surgery.py -m ../MiniCPM-Llama3-V-2_5 +python ./tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-Llama3-V-2_5 --minicpmv-projector ../MiniCPM-Llama3-V-2_5/minicpmv.projector --output-dir ../MiniCPM-Llama3-V-2_5/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 2 python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model # quantize int4 version @@ -39,9 +39,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model Inference on Linux or Mac ```bash -# run f16 version -./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" -# run quantized int4 version -./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf ``` diff --git a/examples/llava/README-minicpmv2.6.md b/docs/multimodal/minicpmv2.6.md similarity index 54% rename from examples/llava/README-minicpmv2.6.md rename to docs/multimodal/minicpmv2.6.md index 2df39cdba..15c1bbd12 100644 --- a/examples/llava/README-minicpmv2.6.md +++ b/docs/multimodal/minicpmv2.6.md @@ -28,8 +28,8 @@ cmake --build build --config Release Convert PyTorch model to gguf files (You can also download the converted [gguf](https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf) by us) ```bash -python ./examples/llava/minicpmv-surgery.py -m ../MiniCPM-V-2_6 -python ./examples/llava/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3 +python ./tools/mtmd/minicpmv-surgery.py -m ../MiniCPM-V-2_6 +python ./tools/mtmd/minicpmv-convert-image-encoder-to-gguf.py -m ../MiniCPM-V-2_6 --minicpmv-projector ../MiniCPM-V-2_6/minicpmv.projector --output-dir ../MiniCPM-V-2_6/ --image-mean 0.5 0.5 0.5 --image-std 0.5 0.5 0.5 --minicpmv_version 3 python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model # quantize int4 version @@ -39,9 +39,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model Inference on Linux or Mac ```bash -# run f16 version -./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +# run in single-turn mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" -# run quantized int4 version -./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?" +# run in conversation mode +./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf ``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 66cfab2c3..49e4d2cf8 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -12,60 +12,30 @@ llama_add_compile_flags() # examples -include_directories(${CMAKE_CURRENT_SOURCE_DIR}) - if (EMSCRIPTEN) else() - add_subdirectory(batched-bench) add_subdirectory(batched) add_subdirectory(embedding) add_subdirectory(eval-callback) - if (NOT WIN32) - # disabled on Windows because it uses internal functions not exported with LLAMA_API - add_subdirectory(gbnf-validator) - endif() - add_subdirectory(gguf-hash) - add_subdirectory(gguf-split) add_subdirectory(gguf) add_subdirectory(gritlm) - add_subdirectory(imatrix) - add_subdirectory(infill) - add_subdirectory(llama-bench) add_subdirectory(lookahead) add_subdirectory(lookup) - add_subdirectory(main) add_subdirectory(parallel) add_subdirectory(passkey) - add_subdirectory(perplexity) - add_subdirectory(quantize) add_subdirectory(retrieval) - if (LLAMA_BUILD_SERVER) - add_subdirectory(server) - endif() add_subdirectory(save-load-state) - add_subdirectory(run) add_subdirectory(simple) add_subdirectory(simple-chat) add_subdirectory(speculative) add_subdirectory(speculative-simple) - add_subdirectory(tokenize) - add_subdirectory(tts) add_subdirectory(gen-docs) + add_subdirectory(training) if (NOT GGML_BACKEND_DL) - # these examples use the backends directly and cannot be built with dynamic loading add_subdirectory(convert-llama2c-to-ggml) - add_subdirectory(cvector-generator) - add_subdirectory(export-lora) - if (NOT WIN32) - # disabled on Windows because it uses internal functions not exported with LLAMA_API - add_subdirectory(quantize-stats) - endif() - add_subdirectory(llava) - if (GGML_RPC) - add_subdirectory(rpc) - endif() + # these examples use the backends directly and cannot be built with dynamic loading if (GGML_SYCL) add_subdirectory(sycl) endif() diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 514989e34..fd90bbec5 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) + llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens) } if n_parallel > 1 { diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 6f0890415..681929d27 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,23 +35,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); - const struct llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), true); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { - // encoder-only model - if (llama_encode(ctx, batch) < 0) { - LOG_ERR("%s : failed to encode\n", __func__); - } - } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { - // decoder-only model - if (llama_decode(ctx, batch) < 0) { - LOG_ERR("%s : failed to decode\n", __func__); - } + if (llama_decode(ctx, batch) < 0) { + LOG_ERR("%s : failed to process\n", __func__); } for (int i = 0; i < batch.n_tokens; i++) { @@ -89,6 +80,13 @@ int main(int argc, char ** argv) { common_init(); params.embedding = true; + + // utilize the full context + if (params.n_batch < params.n_ctx) { + LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx); + params.n_batch = params.n_ctx; + } + // For non-causal models, batch size must be equal to ubatch size params.n_ubatch = params.n_batch; @@ -134,7 +132,6 @@ int main(int argc, char ** argv) { // max batch size const uint64_t n_batch = params.n_batch; - GGML_ASSERT(params.n_batch >= params.n_ctx); // tokenize the prompts and trim std::vector> inputs; @@ -239,9 +236,24 @@ int main(int argc, char ** argv) { LOG("\n"); } } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + std::vector cls_out_labels; + + for (uint32_t i = 0; i < n_cls_out; i++) { + const char * label = llama_model_cls_label(model, i); + const std::string label_i(label == nullptr ? "" : label); + cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i); + } + for (int j = 0; j < n_embd_count; j++) { - // NOTE: if you change this log - update the tests in ci/run.sh - LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); + for (uint32_t i = 0; i < n_cls_out; i++) { + // NOTE: if you change this log - update the tests in ci/run.sh + if (n_cls_out == 1) { + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); + } else { + LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str()); + } + } } } else { // print the first part of the embeddings or for a single prompt, the full embedding diff --git a/examples/gbnf-validator/CMakeLists.txt b/examples/gbnf-validator/CMakeLists.txt deleted file mode 100644 index d2cb524c0..000000000 --- a/examples/gbnf-validator/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -set(TARGET llama-gbnf-validator) -add_executable(${TARGET} gbnf-validator.cpp) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 539bc4d60..bdab052c3 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -41,12 +41,11 @@ static std::vector> encode(llama_context * ctx, const std::ve // add input to batch (this increments n_tokens) for (int32_t j = 0; j < n_toks; j++) { - common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); + common_batch_add(batch, inputs[j], j, { 0 }, true); } // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_self_clear(ctx); - llama_set_embeddings(ctx, true); + llama_memory_clear(llama_get_memory(ctx), true); llama_set_causal_attn(ctx, false); // run model @@ -102,8 +101,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_token eos_token = llama_vocab_eos(vocab); - llama_kv_self_clear(ctx); - llama_set_embeddings(ctx, false); + llama_memory_clear(llama_get_memory(ctx), true); llama_set_causal_attn(ctx, true); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); @@ -166,6 +164,8 @@ int main(int argc, char * argv[]) { llama_model_params mparams = common_model_params_to_llama(params); llama_context_params cparams = common_context_params_to_llama(params); + cparams.embeddings = true; + llama_backend_init(); llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); @@ -213,6 +213,8 @@ int main(int argc, char * argv[]) { std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1); } + llama_set_embeddings(ctx, false); + // ### Generation ### // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction { diff --git a/examples/infill/CMakeLists.txt b/examples/infill/CMakeLists.txt deleted file mode 100644 index fb26628d8..000000000 --- a/examples/infill/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -set(TARGET llama-infill) -add_executable(${TARGET} infill.cpp) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/infill/README.md b/examples/infill/README.md deleted file mode 100644 index df4d976f2..000000000 --- a/examples/infill/README.md +++ /dev/null @@ -1,47 +0,0 @@ -# llama.cpp/example/infill - -This example shows how to use the infill mode with Code Llama models supporting infill mode. -Currently the 7B and 13B models support infill mode. - -Infill supports most of the options available in the main example. - -For further information have a look at the main README.md in llama.cpp/example/main/README.md - -## Common Options - -In this section, we cover the most commonly used options for running the `infill` program with the LLaMA models: - -- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). -- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. -- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text. -- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 4096, but if a LLaMA model was built with a longer context, increasing this value will provide better results for longer input/inference. -- `--spm-infill`: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. - -## Input Prompts - -The `infill` program provides several ways to interact with the LLaMA models using input prompts: - -- `--in-prefix PROMPT_BEFORE_CURSOR`: Provide the prefix directly as a command-line option. -- `--in-suffix PROMPT_AFTER_CURSOR`: Provide the suffix directly as a command-line option. -- `--interactive-first`: Run the program in interactive mode and wait for input right away. (More on this below.) - -## Interaction - -The `infill` program offers a seamless way to interact with LLaMA models, allowing users to receive real-time infill suggestions. The interactive mode can be triggered using `--interactive`, and `--interactive-first` - -### Interaction Options - -- `-i, --interactive`: Run the program in interactive mode, allowing users to get real time code suggestions from model. -- `--interactive-first`: Run the program in interactive mode and immediately wait for user input before starting the text generation. -- `--color`: Enable colorized output to differentiate visually distinguishing between prompts, user input, and generated text. - -### Example - -Download a model that supports infill, for example CodeLlama: -```console -scripts/hf.sh --repo TheBloke/CodeLlama-13B-GGUF --file codellama-13b.Q5_K_S.gguf --outdir models -``` - -```bash -./llama-infill -t 10 -ngl 0 -m models/codellama-13b.Q5_K_S.gguf -c 4096 --temp 0.7 --repeat_penalty 1.1 -n 20 --in-prefix "def helloworld():\n print(\"hell" --in-suffix "\n print(\"goodbye world\")\n " -``` diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp deleted file mode 100644 index 4e2f7b727..000000000 --- a/examples/infill/infill.cpp +++ /dev/null @@ -1,590 +0,0 @@ -#include "arg.h" -#include "common.h" -#include "console.h" -#include "sampling.h" -#include "log.h" -#include "llama.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) -#include -#include -#elif defined (_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#include -#endif - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -static llama_context ** g_ctx; -static llama_model ** g_model; -static common_sampler ** g_smpl; -static common_params * g_params; -static std::vector * g_input_tokens; -static std::ostringstream * g_output_ss; -static std::vector * g_output_tokens; - -static bool is_interacting = false; - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) -static void sigint_handler(int signo) { - if (signo == SIGINT) { - if (!is_interacting) { - is_interacting = true; - } else { - console::cleanup(); - LOG("\n"); - common_perf_print(*g_ctx, *g_smpl); - - // make sure all logs are flushed - LOG("Interrupted by user\n"); - common_log_pause(common_log_main()); - - _exit(130); - } - } -} -#endif - -int main(int argc, char ** argv) { - common_params params; - g_params = ¶ms; - - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_INFILL)) { - return 1; - } - - common_init(); - - auto & sparams = params.sampling; - - console::init(params.simple_io, params.use_color); - atexit([]() { console::cleanup(); }); - - if (params.logits_all) { - LOG_ERR("\n************\n"); - LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); - LOG_ERR("************\n\n"); - - return 0; - } - - if (params.embedding) { - LOG_ERR("\n************\n"); - LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__); - LOG_ERR("************\n\n"); - - return 0; - } - - if (params.n_ctx != 0 && params.n_ctx < 8) { - LOG_WRN("%s: minimum context size is 8, using minimum size.\n", __func__); - params.n_ctx = 8; - } - - if (!params.interactive_first && (params.input_prefix.empty() && params.input_suffix.empty())) { - LOG_ERR("\n************\n"); - LOG_ERR("%s: please use '--interactive_first' or specify '--in_prefix' and/or '--in_suffix'\n", __func__); - LOG_ERR("************\n\n"); - - return 0; - } - - if (params.rope_freq_base != 0.0) { - LOG_WRN("%s: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); - } - - if (params.rope_freq_scale != 0.0) { - LOG_WRN("%s: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); - } - - LOG_INF("%s: llama backend init\n", __func__); - llama_backend_init(); - llama_numa_init(params.numa); - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - common_sampler * smpl = nullptr; - - g_model = &model; - g_ctx = &ctx; - g_smpl = &smpl; - - // load the model and apply lora adapter, if any - LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); - common_init_result llama_init = common_init_from_params(params); - - model = llama_init.model.get(); - ctx = llama_init.context.get(); - - if (model == NULL) { - LOG_ERR("%s: unable to load model\n", __func__); - return 1; - } - - const llama_vocab * vocab = llama_model_get_vocab(model); - - const int n_ctx_train = llama_model_n_ctx_train(model); - const int n_ctx = llama_n_ctx(ctx); - LOG_DBG("n_ctx: %d\n", n_ctx); - - if (n_ctx > n_ctx_train) { - LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); - } - - // print system information - { - LOG_INF("\n"); - LOG_INF("%s\n", common_params_get_system_info(params).c_str()); - } - const bool add_bos = llama_vocab_get_add_bos(vocab); - GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); - - std::vector embd_inp; - std::vector embd_end; - std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); - std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - - GGML_ASSERT(llama_vocab_fim_pre(vocab) >= 0); - GGML_ASSERT(llama_vocab_fim_suf(vocab) >= 0); - - inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); - inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); - - embd_inp = params.spm_infill ? inp_sfx : inp_pfx; - embd_end = params.spm_infill ? inp_pfx : inp_sfx; - if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); - } - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - - const llama_token middle_token = llama_vocab_fim_mid(vocab); - if (middle_token >= 0) { - embd_inp.push_back(middle_token); - } - - LOG_DBG("add_bos: %d\n", add_bos); - LOG_DBG("prefix: \"%s\"\n", params.input_prefix.c_str()); - LOG_DBG("suffix: \"%s\"\n", params.input_suffix.c_str()); - LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str()); - - // Should not run without any tokens - if (embd_inp.empty()) { - embd_inp.push_back(llama_vocab_bos(vocab)); - LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); - } - - if ((int) embd_inp.size() > n_ctx - 4) { - LOG_ERR("%s: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); - return 1; - } - - // number of tokens to keep when resetting context - if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size()) { - params.n_keep = (int)embd_inp.size(); - } - - LOG_INF("inp_pfx: %s\n", string_from(ctx, inp_pfx).c_str()); - LOG_INF("inp_sfx: %s\n", string_from(ctx, inp_sfx).c_str()); - - // enable interactive mode if interactive start is specified - if (params.interactive_first) { - params.interactive = true; - } - - if (params.verbose_prompt) { - LOG_INF("\n"); - LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); - for (int i = 0; i < (int) embd_inp.size(); i++) { - LOG_INF("%6d -> '%s'\n", embd_inp[i], common_token_to_piece(ctx, embd_inp[i]).c_str()); - } - - if (params.n_keep > 0) { - LOG_INF("%s: static prompt based on n_keep: '", __func__); - for (int i = 0; i < params.n_keep; i++) { - LOG_CNT("%s", common_token_to_piece(ctx, embd_inp[i]).c_str()); - } - LOG_CNT("'\n"); - } - LOG_INF("\n"); - } - - if (params.interactive) { -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = sigint_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif - - LOG_INF("%s: interactive mode on.\n", __func__); - - if (params.input_prefix_bos) { - LOG_INF("Input prefix with BOS\n"); - } - - if (!params.input_prefix.empty()) { - LOG_INF("Input prefix: '%s'\n", params.input_prefix.c_str()); - } - - if (!params.input_suffix.empty()) { - LOG_INF("Input suffix: '%s'\n", params.input_suffix.c_str()); - } - } - smpl = common_sampler_init(model, sparams); - - LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl)); - LOG_INF("sampler params: \n%s\n", sparams.print().c_str()); - LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str()); - - LOG_INF("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); - - LOG_INF("\n"); - LOG_INF("\n##### Infill mode #####\n\n"); - if (params.interactive) { - const char *control_message; - if (params.multiline_input) { - control_message = " - To return control to LLaMA, end your input with '\\'.\n" - " - To return control without starting a new line, end your input with '/'.\n"; - } else { - control_message = " - Press Return to return control to LLaMA.\n" - " - To return control without starting a new line, end your input with '/'.\n" - " - If you want to submit another line, end your input with '\\'.\n"; - } - LOG_INF("== Running in interactive mode. ==\n"); -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) - LOG_INF( " - Press Ctrl+C to interject at any time.\n"); -#endif - LOG_INF( "%s\n", control_message); - - is_interacting = params.interactive_first; - } - - bool input_echo = true; - - int n_past = 0; - int n_remain = params.n_predict; - int n_consumed = 0; - - std::vector input_tokens; g_input_tokens = &input_tokens; - std::vector output_tokens; g_output_tokens = &output_tokens; - std::ostringstream output_ss; g_output_ss = &output_ss; - - // the first thing we will do is to output the prompt, so set color accordingly - console::set_display(console::prompt); - - std::vector embd; - - while (n_remain != 0 || params.interactive) { - // predict - if (!embd.empty()) { - // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via - // --prompt or --file which uses the same value. - int max_embd_size = n_ctx - 4; - - // Ensure the input doesn't exceed the context size by truncating embd if necessary. - if ((int) embd.size() > max_embd_size) { - const int skipped_tokens = (int) embd.size() - max_embd_size; - embd.resize(max_embd_size); - - console::set_display(console::error); - LOG_WRN("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); - console::set_display(console::reset); - } - - // infinite text generation via context swapping - // if we run out of context: - // - take the n_keep first tokens from the original prompt (via n_past) - // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches - if (n_past + (int) embd.size() > n_ctx) { - if (params.n_predict == -2) { - LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); - break; - } - - const int n_left = n_past - params.n_keep - 1; - const int n_discard = n_left/2; - - LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", - n_past, n_left, n_ctx, params.n_keep, n_discard); - - llama_kv_self_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_self_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); - - n_past -= n_discard; - - LOG_DBG("after swap: n_past = %d\n", n_past); - - LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str()); - - } - - // evaluate tokens in batches - // embd is typically prepared beforehand to fit within a batch, but not always - for (int i = 0; i < (int) embd.size(); i += params.n_batch) { - int n_eval = (int) embd.size() - i; - if (n_eval > params.n_batch) { - n_eval = params.n_batch; - } - - LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { - LOG_ERR("%s : failed to eval\n", __func__); - return 1; - } - - n_past += n_eval; - - LOG_DBG("n_past = %d\n", n_past); - } - - } - - embd.clear(); - - if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = common_sampler_sample(smpl, ctx, -1); - - common_sampler_accept(smpl, id, true); - - // LOG_DBG("last: %s\n", string_from(ctx, smpl->prev.to_vector()).c_str()); - - embd.push_back(id); - - // echo this to console - input_echo = true; - - // decrement remaining sampling budget - --n_remain; - - LOG_DBG("n_remain: %d\n", n_remain); - } else { - // some user input remains from prompt or interaction, forward it to processing - LOG_DBG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); - while ((int) embd_inp.size() > n_consumed) { - embd.push_back(embd_inp[n_consumed]); - - // push the prompt in the sampling context in order to apply repetition penalties later - // for the prompt, we don't apply grammar rules - common_sampler_accept(smpl, embd_inp[n_consumed], false); - - ++n_consumed; - if ((int) embd.size() >= params.n_batch) { - break; - } - } - } - - // display text - if (input_echo) { - for (auto id : embd) { - const std::string token_str = common_token_to_piece(ctx, id); - LOG("%s", token_str.c_str()); - - if (embd.size() > 1) { - input_tokens.push_back(id); - } else { - output_tokens.push_back(id); - output_ss << token_str; - } - } - } - // reset color to default if we there is no pending user input - if (input_echo && (int) embd_inp.size() == n_consumed) { - console::set_display(console::reset); - } - - // if not currently processing queued inputs; - if ((int) embd_inp.size() <= n_consumed) { - // deal with eot token in infill mode - if ((common_sampler_last(smpl) == llama_vocab_eot(vocab) || is_interacting) && params.interactive){ - if (is_interacting && !params.interactive_first) { - // print an eot token - LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); - } - LOG("\n"); - console::set_display(console::user_input); - std::string buffer; - std::string line; - bool another_line=true; - // set a new prefix via stdin - do { - another_line = console::readline(line, params.multiline_input); - buffer += line; - } while (another_line); - // check if we got an empty line, if so we use the old input - if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { - params.input_prefix = buffer; - } - buffer.clear(); - // set a new suffix via stdin - do { - another_line = console::readline(line, params.multiline_input); - buffer += line; - } while (another_line); - // check if we got an empty line - if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) { - params.input_suffix = buffer; - } - buffer.clear(); - // done taking input, reset color - console::set_display(console::reset); - - if (params.escape) { - //process escape sequences, for the initial prompt this is done in common.cpp when we load the params, but for the interactive mode we need to do it here - string_process_escapes(params.input_prefix); - string_process_escapes(params.input_suffix); - } - - // tokenize new prefix and suffix - std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); - std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - - inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); - inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); - - embd_inp = params.spm_infill ? inp_sfx : inp_pfx; - embd_end = params.spm_infill ? inp_pfx : inp_sfx; - if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); - } - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - - if (middle_token >= 0) { - embd_inp.push_back(middle_token); - } - - embd.clear(); - n_remain = params.n_predict; - n_past = 0; - n_consumed = 0; - is_interacting = false; - } - // deal with end of generation tokens in interactive mode - else if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { - LOG_DBG("found EOS token\n"); - - if (params.interactive) { - - is_interacting = true; - LOG("\n"); - console::set_display(console::user_input); - } - } - - if (n_past > 0 && is_interacting && !params.interactive) { - LOG_DBG("waiting for user input\n"); - - if (params.input_prefix_bos) { - LOG_DBG("adding input prefix BOS token\n"); - embd_inp.push_back(llama_vocab_bos(vocab)); - } - - std::string buffer; - if (!params.input_prefix.empty()) { - LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str()); - buffer += params.input_prefix; - LOG("%s", buffer.c_str()); - } - - std::string line; - bool another_line = true; - do { - another_line = console::readline(line, params.multiline_input); - buffer += line; - } while (another_line); - - // done taking input, reset color - console::set_display(console::reset); - - // Add tokens to embd only if the input buffer is non-empty - // Entering a empty line lets the user pass control back - if (buffer.length() > 1) { - // append input suffix if any - if (!params.input_suffix.empty()) { - LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str()); - buffer += params.input_suffix; - LOG("%s", params.input_suffix.c_str()); - } - - LOG_DBG("buffer: '%s'\n", buffer.c_str()); - - const size_t original_size = embd_inp.size(); - - const auto line_inp = common_tokenize(ctx, buffer, false); - LOG_DBG("input tokens: %s\n", string_from(ctx, line_inp).c_str()); - - embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); - - for (size_t i = original_size; i < embd_inp.size(); ++i) { - const llama_token token = embd_inp[i]; - output_tokens.push_back(token); - output_ss << common_token_to_piece(ctx, token); - } - - n_remain -= line_inp.size(); - LOG_DBG("n_remain: %d\n", n_remain); - } else { - LOG_DBG("empty line, passing control back\n"); - } - - input_echo = false; // do not echo this again - } - - if (n_past > 0) { - if (is_interacting) { - common_sampler_reset(smpl); - } - is_interacting = false; - } - } - - // end of generation - if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !params.interactive) { - break; - } - - // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. - // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size). - if (params.interactive && n_remain <= 0 && params.n_predict >= 0) { - n_remain = params.n_predict; - is_interacting = true; - } - } - if (!params.interactive && n_remain <= 0) { - LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); - } - - LOG("\n"); - common_perf_print(ctx, smpl); - - common_sampler_free(smpl); - llama_backend_free(); - - return 0; -} diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 55f94c0b0..ed3795855 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -10,6 +10,9 @@ from typing import Any, List, Optional, Set, Tuple, Union def _build_repetition(item_rule, min_items, max_items, separator_rule=None): + if max_items == 0: + return "" + if min_items == 0 and max_items == 1: return f'{item_rule}?' diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 9654cd53c..711ddc5d1 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_kv_self_clear(context); + llama_memory_clear(llama_get_memory(context), false); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_kv_self_clear(context); + llama_memory_clear(llama_get_memory(context), false); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_kv_self_clear(context); + llama_memory_clear(llama_get_memory(context), false); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_kv_self_clear(reinterpret_cast(context)); + llama_memory_clear(llama_get_memory(reinterpret_cast(context)), true); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index f6e31abc9..dc2bafc88 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -210,7 +210,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), false) let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000; @@ -223,7 +223,7 @@ actor LlamaContext { // bench text generation - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), false) let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000; @@ -242,7 +242,7 @@ actor LlamaContext { let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000; - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), false) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -292,7 +292,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_kv_self_clear(context) + llama_memory_clear(llama_get_memory(context), true) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt deleted file mode 100644 index 2d5061de4..000000000 --- a/examples/llava/CMakeLists.txt +++ /dev/null @@ -1,97 +0,0 @@ -# llava (legacy) - -add_library(llava OBJECT - llava.cpp - llava.h - clip.cpp - clip.h - ) - -target_link_libraries(llava PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - -target_include_directories(llava PUBLIC .) -target_include_directories(llava PUBLIC ../..) -target_include_directories(llava PUBLIC ../../common) - -target_compile_features(llava PRIVATE cxx_std_17) - -add_library(llava_static STATIC $) -if (BUILD_SHARED_LIBS) - set_target_properties(llava PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(llava PRIVATE LLAMA_SHARED LLAMA_BUILD) - add_library(llava_shared SHARED $) - target_link_libraries(llava_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - install(TARGETS llava_shared LIBRARY) -endif() - -# mtmd - -add_library(mtmd OBJECT - mtmd.cpp - mtmd.h - clip.cpp - clip.h - clip-impl.h - ) - -target_link_libraries(mtmd PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - -target_include_directories(mtmd PUBLIC .) -target_include_directories(mtmd PRIVATE ../..) -target_include_directories(mtmd PRIVATE ../../common) # for stb_image.h - -target_compile_features(mtmd PRIVATE cxx_std_17) - -add_library(mtmd_static STATIC $) -if (BUILD_SHARED_LIBS) - set_target_properties(mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(mtmd PRIVATE LLAMA_SHARED LLAMA_BUILD) - add_library(mtmd_shared SHARED $) - target_link_libraries(mtmd_shared PRIVATE ggml llama ${CMAKE_THREAD_LIBS_INIT}) - install(TARGETS mtmd_shared LIBRARY) -endif() - -if (NOT MSVC) - target_compile_options(llava PRIVATE -Wno-cast-qual) # stb_image.h - target_compile_options(mtmd PRIVATE -Wno-cast-qual) # stb_image.h -endif() - -if(TARGET BUILD_INFO) - add_dependencies(llava BUILD_INFO) - add_dependencies(mtmd BUILD_INFO) -endif() - -set(TARGET llama-llava-cli) -add_executable(${TARGET} llava-cli.cpp) -set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-cli) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) - -set(TARGET llama-minicpmv-cli) -add_executable(${TARGET} minicpmv-cli.cpp) -set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-minicpmv-cli) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) - -set(TARGET llama-qwen2vl-cli) -add_executable(${TARGET} qwen2vl-cli.cpp) -set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) - -set(TARGET llama-gemma3-cli) -add_executable(${TARGET} gemma3-cli.cpp) -set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) - -set(TARGET llama-llava-clip-quantize-cli) -add_executable(${TARGET} clip-quantize-cli.cpp) -set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-clip-quantize-cli) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/llava/README-quantize.md b/examples/llava/README-quantize.md deleted file mode 100644 index b931513ab..000000000 --- a/examples/llava/README-quantize.md +++ /dev/null @@ -1,44 +0,0 @@ -# Quantizing CLIP Visual Projector - -This is the tool for quantizing the CLIP visual projector model. Quantization reduces the precision of the model's weights, which can significantly decrease the model size and improve inference speed, often with minimal impact on performance. - -## Usage - -To quantize a CLIP visual projector model, use the following command: - -```sh -./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf -``` - -After the quantization, the visual projector can be used freely with the existing LLAVA cli (LLAVA, Qwen2VL, etc). - -### Arguments - -- `/path/to/ggml-model-f32.gguf`: The path to the input model file in FP32 or FP16 format. -- `/path/to/ggml-model-quantized.gguf`: The path where the quantized model will be saved. -- ``: The quantization type to apply. This should be an integer corresponding to one of the quantization types defined in the `enum ggml_type`. - -### Quantization Types - -The following quantization types are supported, based on the `enum ggml_type` definition: - -- `2` - `q4_0`: 4-bit quantization with a single scale value. -- `3` - `q4_1`: 4-bit quantization with a separate scale value for each block. -- `6` - `q5_0`: 5-bit quantization with a single scale value. -- `7` - `q5_1`: 5-bit quantization with a separate scale value for each block. -- `8` - `q8_0`: 8-bit quantization with a single scale value. - -### Example - -To quantize a model using the `q4_0` quantization type, you would run: - -```sh -./bin/llama-llava-clip-quantize-cli /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf 2 -``` - -This command will generate a quantized model at `/path/to/ggml-model-quantized.gguf` using the `q4_0` quantization method. - -## Notes - -- Quantization can lead to a loss in model accuracy, depending on the chosen quantization type. It is recommended to evaluate the quantized model's performance on your specific task to ensure it meets your requirements. -- The quantized model will typically be smaller in size and faster to run, making it more suitable for deployment in resource-constrained environments. diff --git a/examples/llava/android/adb_run.sh b/examples/llava/android/adb_run.sh deleted file mode 100755 index 45ccf8d70..000000000 --- a/examples/llava/android/adb_run.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash - -model_dir="/Users/cxt/model/llm/mobileVLM/MobileVLM-1.7B_processed" -projector_name="mmproj-model-f16.gguf" -llama_name="ggml-model-q4_k.gguf" -img_dir="/Users/cxt/model/llm" -img_name="demo.jpg" -prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? \nAnswer the question using a single word or phrase. ASSISTANT:" -# img_name="cat.jpeg" -# prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" - -program_dir="build_64/bin" -binName="llama-llava-cli" -n_threads=4 - - -deviceDir="/data/local/tmp" -saveDir="output" -if [ ! -d ${saveDir} ]; then - mkdir ${saveDir} -fi - - -function android_run() { - # # copy resource into device - # adb push ${model_dir}/${projector_name} ${deviceDir}/${projector_name} - # adb push ${model_dir}/${llama_name} ${deviceDir}/${llama_name} - adb push ${img_dir}/${img_name} ${deviceDir}/${img_name} - # copy program into device - adb push ${program_dir}/${binName} ${deviceDir}/${binName} - adb shell "chmod 0777 ${deviceDir}/${binName}" - - # run - adb shell "echo cd ${deviceDir} ${deviceDir}/${binName} \ - -m ${deviceDir}/${llama_name} \ - --mmproj ${deviceDir}/${projector_name} \ - -t ${n_threads} \ - --image ${deviceDir}/${img_name} \ - -p \"${prompt}\" \ - > ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt" - adb shell "cd ${deviceDir}; pwd; ${deviceDir}/${binName} \ - -m ${deviceDir}/${llama_name} \ - --mmproj ${deviceDir}/${projector_name} \ - -t ${n_threads} \ - --image ${deviceDir}/${img_name} \ - -p \"${prompt}\" \ - >> ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt 2>&1" - adb pull ${deviceDir}/${modelName}_${projector_name}_${n_threads}_${img_name}.txt ${saveDir} -} - -android_run - -echo "android_run is Done!" diff --git a/examples/llava/android/build_64.sh b/examples/llava/android/build_64.sh deleted file mode 100755 index 71b6fd3f7..000000000 --- a/examples/llava/android/build_64.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -cmake ../../../../ \ --DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ --DCMAKE_BUILD_TYPE=Release \ --DANDROID_ABI="arm64-v8a" \ --DANDROID_PLATFORM=android-23 $1 - -make -j4 diff --git a/examples/llava/clip-quantize-cli.cpp b/examples/llava/clip-quantize-cli.cpp deleted file mode 100644 index 566506954..000000000 --- a/examples/llava/clip-quantize-cli.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "arg.h" -#include "base64.hpp" -#include "log.h" -#include "common.h" -#include "sampling.h" -#include "clip.h" -#include "llava.h" -#include "llama.h" -#include "ggml.h" - -static void print_usage(int argc, char ** argv) { - (void) argc; - - fprintf(stderr, "usage: %s /path/to/ggml-model-f32.gguf /path/to/ggml-model-quantized.gguf type\n", argv[0]); - fprintf(stderr, " type = 2 - q4_0\n"); - fprintf(stderr, " type = 3 - q4_1\n"); - fprintf(stderr, " type = 6 - q5_0\n"); - fprintf(stderr, " type = 7 - q5_1\n"); - fprintf(stderr, " type = 8 - q8_0\n"); -} - -int main(int argc, char ** argv) { - if (argc != 4) { - print_usage(argc, argv); - return 1; - } - - const std::string fname_inp = argv[1]; - const std::string fname_out = argv[2]; - - const int itype = atoi(argv[3]); - - const int64_t t_main_start_us = ggml_time_us(); - - int64_t t_quantize_us = 0; - - // load the model - { - const int64_t t_start_us = ggml_time_us(); - - if (!clip_model_quantize(fname_inp.c_str(), fname_out.c_str(), itype)) { - fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); - return 1; - } - - t_quantize_us = ggml_time_us() - t_start_us; - } - - // report timing - { - const int64_t t_main_end_us = ggml_time_us(); - - printf("\n"); - printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us / 1000.0f); - printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f); - } - - return 0; -} diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp deleted file mode 100644 index 49c90b750..000000000 --- a/examples/llava/clip.cpp +++ /dev/null @@ -1,2882 +0,0 @@ -// NOTE: This is modified from clip.cpp only for LLaVA, -// so there might be still unnecessary artifacts hanging around -// I'll gradually clean and extend it -// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch -#include "clip.h" -#include "clip-impl.h" -#include "ggml.h" -#include "ggml-cpp.h" -#include "ggml-cpu.h" -#include "ggml-alloc.h" -#include "ggml-backend.h" -#include "gguf.h" - -#define STB_IMAGE_IMPLEMENTATION -#include "stb_image.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; - -//#define CLIP_DEBUG_FUNCTIONS - -#ifdef CLIP_DEBUG_FUNCTIONS -static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) { - std::ofstream file(filename, std::ios::binary); - if (!file.is_open()) { - LOG_ERR("Failed to open file for writing: %s\n", filename.c_str()); - return; - } - - // PPM header: P6 format, width, height, and max color value - file << "P6\n" << img.nx << " " << img.ny << "\n255\n"; - - // Write pixel data - for (size_t i = 0; i < img.buf.size(); i += 3) { - // PPM expects binary data in RGB format, which matches our image buffer - file.write(reinterpret_cast(&img.buf[i]), 3); - } - - file.close(); -} - -static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) { - std::ofstream file(filename, std::ios::binary); - if (!file.is_open()) { - LOG_ERR("Failed to open file for writing: %s\n", filename.c_str()); - return; - } - - int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data - int bytesPerPixel = 3; - int widthInBytes = img.nx * bytesPerPixel; - int paddingAmount = (4 - (widthInBytes % 4)) % 4; - int stride = widthInBytes + paddingAmount; - - // Bitmap file header - unsigned char fileHeader[14] = { - 'B','M', // Signature - 0,0,0,0, // Image file size in bytes - 0,0,0,0, // Reserved - 54,0,0,0 // Start of pixel array - }; - - // Total file size - fileSize = 54 + (stride * img.ny); - fileHeader[2] = (unsigned char)(fileSize); - fileHeader[3] = (unsigned char)(fileSize >> 8); - fileHeader[4] = (unsigned char)(fileSize >> 16); - fileHeader[5] = (unsigned char)(fileSize >> 24); - - // Bitmap information header (BITMAPINFOHEADER) - unsigned char infoHeader[40] = { - 40,0,0,0, // Size of this header (40 bytes) - 0,0,0,0, // Image width - 0,0,0,0, // Image height - 1,0, // Number of color planes - 24,0, // Bits per pixel - 0,0,0,0, // No compression - 0,0,0,0, // Image size (can be 0 for no compression) - 0,0,0,0, // X pixels per meter (not specified) - 0,0,0,0, // Y pixels per meter (not specified) - 0,0,0,0, // Total colors (color table not used) - 0,0,0,0 // Important colors (all are important) - }; - - // Width and height in the information header - infoHeader[4] = (unsigned char)(img.nx); - infoHeader[5] = (unsigned char)(img.nx >> 8); - infoHeader[6] = (unsigned char)(img.nx >> 16); - infoHeader[7] = (unsigned char)(img.nx >> 24); - infoHeader[8] = (unsigned char)(img.ny); - infoHeader[9] = (unsigned char)(img.ny >> 8); - infoHeader[10] = (unsigned char)(img.ny >> 16); - infoHeader[11] = (unsigned char)(img.ny >> 24); - - // Write file headers - file.write(reinterpret_cast(fileHeader), sizeof(fileHeader)); - file.write(reinterpret_cast(infoHeader), sizeof(infoHeader)); - - // Pixel data - std::vector padding(3, 0); // Max padding size to be added to each row - for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top - for (int x = 0; x < img.nx; ++x) { - // Each pixel - size_t pixelIndex = (y * img.nx + x) * 3; - unsigned char pixel[3] = { - img.buf[pixelIndex + 2], // BMP stores pixels in BGR format - img.buf[pixelIndex + 1], - img.buf[pixelIndex] - }; - file.write(reinterpret_cast(pixel), 3); - } - // Write padding for the row - file.write(reinterpret_cast(padding.data()), paddingAmount); - } - - file.close(); -} - -// debug function to convert f32 to u8 -static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) { - dst.nx = src.nx; - dst.ny = src.ny; - dst.buf.resize(3 * src.nx * src.ny); - for (size_t i = 0; i < src.buf.size(); ++i) { - dst.buf[i] = static_cast(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255)); - } -} -#endif - - -// -// clip layers -// - -enum patch_merge_type { - PATCH_MERGE_FLAT, - PATCH_MERGE_SPATIAL_UNPAD, -}; - -struct clip_hparams { - int32_t image_size; - int32_t patch_size; - int32_t hidden_size; - int32_t n_intermediate; - int32_t projection_dim; - int32_t n_head; - int32_t n_layer; - - patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT; - - float eps; - - std::vector image_grid_pinpoints; - int32_t image_crop_resolution; - std::unordered_set vision_feature_layer; -}; - -struct clip_layer { - // attention - struct ggml_tensor * k_w = nullptr; - struct ggml_tensor * k_b = nullptr; - struct ggml_tensor * q_w = nullptr; - struct ggml_tensor * q_b = nullptr; - struct ggml_tensor * v_w = nullptr; - struct ggml_tensor * v_b = nullptr; - - struct ggml_tensor * o_w = nullptr; - struct ggml_tensor * o_b = nullptr; - - // layernorm 1 - struct ggml_tensor * ln_1_w = nullptr; - struct ggml_tensor * ln_1_b = nullptr; - - // ff - struct ggml_tensor * ff_i_w = nullptr; - struct ggml_tensor * ff_i_b = nullptr; - - struct ggml_tensor * ff_o_w = nullptr; - struct ggml_tensor * ff_o_b = nullptr; - - // layernorm 2 - struct ggml_tensor * ln_2_w = nullptr; - struct ggml_tensor * ln_2_b = nullptr; -}; - -struct clip_vision_model { - struct clip_hparams hparams; - - // embeddings - struct ggml_tensor * class_embedding = nullptr; - struct ggml_tensor * patch_embeddings_0 = nullptr; - struct ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) - struct ggml_tensor * patch_bias = nullptr; - struct ggml_tensor * position_embeddings = nullptr; - - struct ggml_tensor * pre_ln_w = nullptr; - struct ggml_tensor * pre_ln_b = nullptr; - - std::vector layers; - - struct ggml_tensor * post_ln_w; - struct ggml_tensor * post_ln_b; - - struct ggml_tensor * projection; - - // LLaVA projection - struct ggml_tensor * mm_0_w = nullptr; - struct ggml_tensor * mm_0_b = nullptr; - struct ggml_tensor * mm_2_w = nullptr; - struct ggml_tensor * mm_2_b = nullptr; - - struct ggml_tensor * image_newline = nullptr; - - // Yi type models with mlp+normalization projection - struct ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4 - struct ggml_tensor * mm_1_b = nullptr; - struct ggml_tensor * mm_3_w = nullptr; - struct ggml_tensor * mm_3_b = nullptr; - struct ggml_tensor * mm_4_w = nullptr; - struct ggml_tensor * mm_4_b = nullptr; - - //GLMV-Edge projection - struct ggml_tensor * mm_model_adapter_conv_w = nullptr; - struct ggml_tensor * mm_model_adapter_conv_b = nullptr; - struct ggml_tensor * boi_w = nullptr; - struct ggml_tensor * eoi_w = nullptr; - - // MobileVLM projection - struct ggml_tensor * mm_model_mlp_1_w = nullptr; - struct ggml_tensor * mm_model_mlp_1_b = nullptr; - struct ggml_tensor * mm_model_mlp_3_w = nullptr; - struct ggml_tensor * mm_model_mlp_3_b = nullptr; - struct ggml_tensor * mm_model_block_1_block_0_0_w = nullptr; - struct ggml_tensor * mm_model_block_1_block_0_1_w = nullptr; - struct ggml_tensor * mm_model_block_1_block_0_1_b = nullptr; - struct ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr; - struct ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr; - struct ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr; - struct ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr; - struct ggml_tensor * mm_model_block_1_block_2_0_w = nullptr; - struct ggml_tensor * mm_model_block_1_block_2_1_w = nullptr; - struct ggml_tensor * mm_model_block_1_block_2_1_b = nullptr; - struct ggml_tensor * mm_model_block_2_block_0_0_w = nullptr; - struct ggml_tensor * mm_model_block_2_block_0_1_w = nullptr; - struct ggml_tensor * mm_model_block_2_block_0_1_b = nullptr; - struct ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr; - struct ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr; - struct ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr; - struct ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr; - struct ggml_tensor * mm_model_block_2_block_2_0_w = nullptr; - struct ggml_tensor * mm_model_block_2_block_2_1_w = nullptr; - struct ggml_tensor * mm_model_block_2_block_2_1_b = nullptr; - - // MobileVLM_V2 projection - struct ggml_tensor * mm_model_mlp_0_w = nullptr; - struct ggml_tensor * mm_model_mlp_0_b = nullptr; - struct ggml_tensor * mm_model_mlp_2_w = nullptr; - struct ggml_tensor * mm_model_mlp_2_b = nullptr; - struct ggml_tensor * mm_model_peg_0_w = nullptr; - struct ggml_tensor * mm_model_peg_0_b = nullptr; - - // MINICPMV projection - struct ggml_tensor * mm_model_pos_embed_k = nullptr; - struct ggml_tensor * mm_model_query = nullptr; - struct ggml_tensor * mm_model_proj = nullptr; - struct ggml_tensor * mm_model_kv_proj = nullptr; - struct ggml_tensor * mm_model_attn_q_w = nullptr; - struct ggml_tensor * mm_model_attn_q_b = nullptr; - struct ggml_tensor * mm_model_attn_k_w = nullptr; - struct ggml_tensor * mm_model_attn_k_b = nullptr; - struct ggml_tensor * mm_model_attn_v_w = nullptr; - struct ggml_tensor * mm_model_attn_v_b = nullptr; - struct ggml_tensor * mm_model_attn_o_w = nullptr; - struct ggml_tensor * mm_model_attn_o_b = nullptr; - struct ggml_tensor * mm_model_ln_q_w = nullptr; - struct ggml_tensor * mm_model_ln_q_b = nullptr; - struct ggml_tensor * mm_model_ln_kv_w = nullptr; - struct ggml_tensor * mm_model_ln_kv_b = nullptr; - struct ggml_tensor * mm_model_ln_post_w = nullptr; - struct ggml_tensor * mm_model_ln_post_b = nullptr; - - // gemma3 - struct ggml_tensor * mm_input_proj_w = nullptr; - struct ggml_tensor * mm_soft_emb_norm_w = nullptr; -}; - -struct clip_ctx { - bool has_text_encoder = false; - bool has_vision_encoder = false; - bool has_llava_projector = false; - bool has_minicpmv_projector = false; - bool has_glm_projector = false; - bool has_qwen2vl_merger = false; - int minicpmv_version = 2; - - struct clip_vision_model vision_model; - projector_type proj_type = PROJECTOR_TYPE_MLP; - - int32_t max_feature_layer; // unused in newer models like gemma3 - float image_mean[3]; - float image_std[3]; - bool use_gelu = false; - bool use_silu = false; - - gguf_context_ptr ctx_gguf; - ggml_context_ptr ctx_data; - - std::vector buf_compute_meta; - - std::vector backend_ptrs; - std::vector backend_buft; - - ggml_backend_t backend; - ggml_backend_t backend_cpu; - ggml_backend_buffer_ptr buf; - - ggml_backend_sched_ptr sched; - - clip_image_size load_image_size; - - clip_ctx(clip_context_params & ctx_params) { - backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); - backend = ctx_params.use_gpu - ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr) - : nullptr; - - if (backend) { - LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend)); - backend_ptrs.push_back(backend); - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend)); - } else { - backend = backend_cpu; - LOG_INF("%s: CLIP using CPU backend\n", __func__); - } - - backend_ptrs.push_back(backend_cpu); - backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu)); - - sched.reset( - ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false) - ); - } - - ~clip_ctx() { - ggml_backend_free(backend); - if (backend != backend_cpu) { - ggml_backend_free(backend_cpu); - } - } -}; - -static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) { - const auto & model = ctx->vision_model; - const auto & hparams = model.hparams; - - const int image_size = hparams.image_size; - int image_size_width = image_size; - int image_size_height = image_size; - - const int patch_size = hparams.patch_size; - const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); - const int hidden_size = hparams.hidden_size; - const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; - const int n_layer = hparams.n_layer; - const float eps = hparams.eps; - - GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1 - - struct ggml_init_params params = { - /*.mem_size =*/ ctx->buf_compute_meta.size(), - /*.mem_buffer =*/ ctx->buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx0_ptr(ggml_init(params)); - auto ctx0 = ctx0_ptr.get(); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - // input raw - struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3); - ggml_set_name(inp_raw, "inp_raw"); - ggml_set_input(inp_raw); - - struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size); - inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); - inp = ggml_add(ctx0, inp, model.patch_bias); - - // position embeddings - struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings); - - // loop over layers - for (int il = 0; il < n_layer; il++) { - struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states - - // layernorm1 - { - cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b); - } - - // self-attention - { - - struct ggml_tensor * Q = - ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); - - Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches); - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - - struct ggml_tensor * K = - ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); - - K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches); - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - - struct ggml_tensor * V = - ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); - - V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); - - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches); - } - - // attention output - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b); - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, embeddings); - - embeddings = cur; // embeddings = residual, cur = hidden_states - - // layernorm2 - { - cur = ggml_norm(ctx0, cur, eps); - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); - - // siglip uses gelu - cur = ggml_gelu(ctx0, cur); - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); - - // residual 2 - cur = ggml_add(ctx0, embeddings, cur); - - embeddings = cur; - } - - // post-layernorm - if (model.post_ln_w) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "post_ln"); - - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); - } - - if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - const int batch_size = 1; - const int mm_tokens_per_image = 256; // default value for gemma3 - const int tokens_per_side = sqrt(mm_tokens_per_image); - const int patches_per_image = sqrt(num_patches); - const int kernel_size = patches_per_image / tokens_per_side; - - embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); - embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, hidden_size, batch_size); - - // doing a pool2d to reduce the number of output tokens to 256 - embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0); - embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], hidden_size, batch_size); - embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings)); - - // apply norm before projection - embeddings = ggml_rms_norm(ctx0, embeddings, eps); - embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w); - - // apply projection - embeddings = ggml_mul_mat(ctx0, - ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)), - embeddings); - } - - // build the graph - ggml_build_forward_expand(gf, embeddings); - - return gf; -} - -static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { - if (!ctx->has_vision_encoder) { - LOG_ERR("This gguf file seems to have no vision encoder\n"); - return nullptr; - } - - const auto & model = ctx->vision_model; - const auto & hparams = model.hparams; - - const int image_size = hparams.image_size; - int image_size_width = image_size; - int image_size_height = image_size; - if (ctx->has_minicpmv_projector) { - LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height); - image_size_width = load_image_size.width; - image_size_height = load_image_size.height; - if (is_inf) { - image_size_width = imgs.entries[0]->nx; - image_size_height = imgs.entries[0]->ny; - } - } - else if (ctx->has_qwen2vl_merger) { - // use the image's native resolution when image is avaible - if (is_inf) { - // if (imgs->data->nx && imgs->data->ny) { - image_size_width = imgs.entries[0]->nx; - image_size_height = imgs.entries[0]->ny; - } - } - const int patch_size = hparams.patch_size; - const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); - const int patches_w = image_size_width / patch_size; - const int patches_h = image_size_height / patch_size; - const int num_positions = num_patches + (model.class_embedding ? 1 : 0); - const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions; - const int hidden_size = hparams.hidden_size; - const int n_head = hparams.n_head; - const int d_head = hidden_size / n_head; - const float eps = hparams.eps; - int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; - - const int batch_size = imgs.entries.size(); - - if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { - GGML_ASSERT(batch_size == 1); - } - - struct ggml_init_params params = { - /*.mem_size =*/ ctx->buf_compute_meta.size(), - /*.mem_buffer =*/ ctx->buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx0_ptr(ggml_init(params)); - auto ctx0 = ctx0_ptr.get(); - - struct ggml_cgraph * gf = ggml_new_graph(ctx0); - - struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size); - ggml_set_name(inp_raw, "inp_raw"); - ggml_set_input(inp_raw); - - struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - - if (ctx->has_qwen2vl_merger) { - GGML_ASSERT(image_size_width % (patch_size * 2) == 0); - GGML_ASSERT(image_size_height % (patch_size * 2) == 0); - - auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_add(ctx0, inp, inp_1); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3)); // [w, h, c, b] -> [c, w, h, b] - inp = ggml_reshape_4d( - ctx0, inp, - hidden_size * 2, patches_w / 2, patches_h, batch_size); - inp = ggml_reshape_4d( - ctx0, inp, - hidden_size * 2, patches_w / 2, 2, batch_size * (patches_h / 2)); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3)); - inp = ggml_reshape_3d( - ctx0, inp, - hidden_size, patches_w * patches_h, batch_size); - } - else { - inp = ggml_reshape_3d(ctx0, inp, num_patches, hidden_size, batch_size); - inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3)); - } - - if (model.patch_bias) { - // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp)); - inp = ggml_add(ctx0, inp, model.patch_bias); - } - struct ggml_tensor * embeddings = inp; - struct ggml_tensor * pos_embed = nullptr; - - if (ctx->has_llava_projector) { - // concat class_embeddings and patch_embeddings - if (model.class_embedding) { - embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size); - ggml_set_name(embeddings, "embeddings"); - ggml_set_input(embeddings); - embeddings = ggml_acc(ctx0, embeddings, model.class_embedding, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0); - embeddings = ggml_acc(ctx0, embeddings, inp, - embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]); - } - } - - struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); - ggml_set_name(positions, "positions"); - ggml_set_input(positions); - - if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding - embeddings = - ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions)); - } - - if (ctx->has_minicpmv_projector) { - int pos_w = image_size_width/patch_size; - int pos_h = image_size_height/patch_size; - if (ctx->minicpmv_version == 2) { - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1); - } - else if (ctx->minicpmv_version == 3) { - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); - } - else if (ctx->minicpmv_version == 4) { - pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); - } - ggml_set_name(pos_embed, "pos_embed"); - ggml_set_input(pos_embed); - } - - // pre-layernorm - if (model.pre_ln_w) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "pre_ln"); - - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); - } - - std::vector embedding_stack; - const auto & vision_feature_layer = hparams.vision_feature_layer; - - // loop over layers - for (int il = 0; il < ctx->max_feature_layer; il++) { - struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states - - // If this is an embedding feature layer, save the output. - // NOTE: 0 index here refers to the input to the encoder. - if (vision_feature_layer.find(il) != vision_feature_layer.end()) { - embedding_stack.push_back(embeddings); - } - - //const size_t nb_q_w = model.layers[il].q_w->nb[0]; - - // layernorm1 - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), - model.layers[il].ln_1_b); - } - - // self-attention - { - - struct ggml_tensor * Q = - ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b); - - Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size); - if (ctx->has_qwen2vl_merger) { - Q = ggml_rope_multi( - ctx0, Q, positions, nullptr, - d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); - } - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size); - - struct ggml_tensor * K = - ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b); - - K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); - if (ctx->has_qwen2vl_merger) { - K = ggml_rope_multi( - ctx0, K, positions, nullptr, - d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); - } - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); - - struct ggml_tensor * V = - ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b); - - V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); - - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - - cur = ggml_cont_3d(ctx0, KQV, hidden_size, num_positions, batch_size); - } - - // attention output - cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b); - - // re-add the layer input, e.g., residual - cur = ggml_add(ctx0, cur, embeddings); - - embeddings = cur; // embeddings = residual, cur = hidden_states - - // layernorm2 - { - cur = ggml_norm(ctx0, cur, eps); - - cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_i_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_i_b); - - if (ctx->use_gelu) { - cur = ggml_gelu_inplace(ctx0, cur); - } else if (ctx->use_silu) { - cur = ggml_silu_inplace(ctx0, cur); - } else { - cur = ggml_gelu_quick_inplace(ctx0, cur); - } - - cur = ggml_mul_mat(ctx0, model.layers[il].ff_o_w, cur); - cur = ggml_add(ctx0, cur, model.layers[il].ff_o_b); - - // residual 2 - cur = ggml_add(ctx0, embeddings, cur); - - embeddings = cur; - } - - // post-layernorm - if (model.post_ln_w) { - embeddings = ggml_norm(ctx0, embeddings, eps); - ggml_set_name(embeddings, "post_ln"); - - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); - } - - // final layer is a vision feature layer - if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) { - embedding_stack.push_back(embeddings); - } - - // If feature layers are explicitly set, stack them (if we have multiple) - if (!embedding_stack.empty()) { - embeddings = embedding_stack[0]; - for (size_t i = 1; i < embedding_stack.size(); i++) { - embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0); - } - } - - // llava projector - if (ctx->has_llava_projector) { - embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); - - struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches); - ggml_set_name(patches, "patches"); - ggml_set_input(patches); - - // shape [1, 576, 1024] - // ne is whcn, ne = [1024, 576, 1, 1] - embeddings = ggml_get_rows(ctx0, embeddings, patches); - - // print_tensor_info(embeddings, "embeddings"); - - // llava projector - if (ctx->proj_type == PROJECTOR_TYPE_MLP) { - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - - embeddings = ggml_gelu(ctx0, embeddings); - if (model.mm_2_w) { - embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_2_b); - } - } - else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false); - // First LayerNorm - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w), - model.mm_1_b); - - // GELU activation - embeddings = ggml_gelu(ctx0, embeddings); - - // Second linear layer - embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_3_b); - - // Second LayerNorm - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w), - model.mm_4_b); - } - else if (ctx->proj_type == PROJECTOR_TYPE_LDP) { - // MobileVLM projector - int n_patch = 24; - struct ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings); - mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b); - mlp_1 = ggml_gelu(ctx0, mlp_1); - struct ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1); - mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b); - // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1] - - // block 1 - struct ggml_tensor * block_1 = nullptr; - { - // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24] - mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3)); - mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]); - // stride = 1, padding = 1, bias is nullptr - block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1); - - // layer norm - // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3)); - // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); - - // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] - // hardswish - struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1); - - block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0); - // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] - // pointwise conv - block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b); - block_1 = ggml_relu(ctx0, block_1); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b); - block_1 = ggml_hardsigmoid(ctx0, block_1); - // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1] - block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]); - block_1 = ggml_mul(ctx0, block_1_hw, block_1); - - int w = block_1->ne[0], h = block_1->ne[1]; - block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); - - // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1); - block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); - - // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); - // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1] - // residual - block_1 = ggml_add(ctx0, mlp_3, block_1); - } - - // block_2 - { - // stride = 2 - block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1); - - // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1] - // layer norm - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3)); - // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3)); - // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1] - // hardswish - struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1); - - // not sure the parameters is right for globalAvgPooling - block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0); - // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] - // pointwise conv - block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b); - block_1 = ggml_relu(ctx0, block_1); - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1); - block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b); - block_1 = ggml_hardsigmoid(ctx0, block_1); - - // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1] - block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]); - block_1 = ggml_mul(ctx0, block_1_hw, block_1); - - int w = block_1->ne[0], h = block_1->ne[1]; - block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]); - block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3)); - // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1] - block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1); - block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]); - - - // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1] - block_1 = ggml_norm(ctx0, block_1, eps); - block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b); - block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]); - // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1] - } - embeddings = block_1; - } - else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) - { - int n_patch = 24; - struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); - mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b); - mlp_0 = ggml_gelu(ctx0, mlp_0); - struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0); - mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b); - // mlp_2 ne = [2048, 576, 1, 1] - // // AVG Pool Layer 2*2, strides = 2 - mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3)); - // mlp_2 ne = [576, 2048, 1, 1] - mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]); - // mlp_2 ne [24, 24, 2048, 1] - mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0); - // weight ne = [3, 3, 2048, 1] - struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1); - peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3)); - peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b); - mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3)); - peg_0 = ggml_add(ctx0, peg_0, mlp_2); - peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]); - embeddings = peg_0; - } - else { - GGML_ABORT("fatal error"); - } - } - // minicpmv projector - else if (ctx->has_minicpmv_projector) - { - if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - struct ggml_tensor * q = model.mm_model_query; - { // layernorm - q = ggml_norm(ctx0, q, eps); - q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b); - } - struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings); - { // layernorm - v = ggml_norm(ctx0, v, eps); - v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b); - } - struct ggml_tensor * k; - { // position - // q = ggml_add(ctx0, q, model.mm_model_pos_embed); - k = ggml_add(ctx0, v, pos_embed); - } - - { // attention - int hidden_size = 4096; - const int d_head = 128; - int n_head = hidden_size/d_head; - int num_query = 96; - if (ctx->minicpmv_version == 2) { - hidden_size = 4096; - n_head = hidden_size/d_head; - num_query = 96; - } - else if (ctx->minicpmv_version == 3) { - hidden_size = 3584; - n_head = hidden_size/d_head; - num_query = 64; - } - else if (ctx->minicpmv_version == 4) { - hidden_size = 3584; - n_head = hidden_size/d_head; - num_query = 64; - } - - struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); - struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b); - struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b); - // permute - Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size); - Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); - Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size); - K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size); - K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); - K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size); - V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size); - V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3)); - V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size); - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ); - KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size); - KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3); - KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size); - - embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b); - } - { // layernorm - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b); - } - embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings); - } - else { - GGML_ASSERT(false); - } - } - // glm projector - else if (ctx->has_glm_projector) { - if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { - size_t gridsz = (size_t)sqrt(embeddings->ne[1]); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); - embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); - embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); - embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); - embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); - embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); - //GLU - { - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); - embeddings = ggml_norm(ctx0, embeddings, eps); - embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); - embeddings = ggml_gelu_inplace(ctx0, embeddings); - struct ggml_tensor * x = embeddings; - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); - x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); - embeddings = ggml_silu_inplace(ctx0, embeddings); - embeddings = ggml_mul(ctx0, embeddings,x); - embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); - } - } else { - GGML_ABORT("fatal error"); - } - } - else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { - embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); - - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - - // GELU activation - embeddings = ggml_gelu(ctx0, embeddings); - - // Second linear layer - embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); - } - - // build the graph - ggml_build_forward_expand(gf, embeddings); - - return gf; -} - -static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) { - if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - return clip_image_build_graph_siglip(ctx, imgs); - } else { - // TODO: we should have one build_* function per model - return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf); - } -} - -struct clip_model_loader { - ggml_context_ptr ctx_meta; - gguf_context_ptr ctx_gguf; - - clip_ctx & ctx_clip; - std::string fname; - - size_t model_size; // in bytes - - // TODO @ngxson : we should not pass clip_ctx here, it should be clip_vision_model - clip_model_loader(const char * fname, clip_ctx & ctx_clip) : ctx_clip(ctx_clip), fname(fname) { - struct ggml_context * meta = nullptr; - - struct gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &meta, - }; - - ctx_gguf = gguf_context_ptr(gguf_init_from_file(fname, params)); - if (!ctx_gguf.get()) { - throw std::runtime_error(string_format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname)); - } - - ctx_meta.reset(meta); - - const int n_tensors = gguf_get_n_tensors(ctx_gguf.get()); - - // print gguf info - { - std::string name; - get_string(KEY_NAME, name, false); - std::string description; - get_string(KEY_DESCRIPTION, description, false); - LOG_INF("%s: model name: %s\n", __func__, name.c_str()); - LOG_INF("%s: description: %s\n", __func__, description.c_str()); - LOG_INF("%s: GGUF version: %d\n", __func__, gguf_get_version(ctx_gguf.get())); - LOG_INF("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx_gguf.get())); - LOG_INF("%s: n_tensors: %d\n", __func__, n_tensors); - LOG_INF("%s: n_kv: %d\n", __func__, (int)gguf_get_n_kv(ctx_gguf.get())); - LOG_INF("\n"); - } - - // tensors - { - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx_gguf.get(), i); - const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i); - enum ggml_type type = gguf_get_tensor_type(ctx_gguf.get(), i); - struct ggml_tensor * cur = ggml_get_tensor(meta, name); - size_t tensor_size = ggml_nbytes(cur); - model_size += tensor_size; - LOG_DBG("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n", - __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type)); - } - } - } - - void load_hparams() { - // projector type - { - std::string proj_type; - get_string(KEY_PROJ_TYPE, proj_type, false); - if (!proj_type.empty()) { - ctx_clip.proj_type = clip_projector_type_from_string(proj_type); - } - if (ctx_clip.proj_type == PROJECTOR_TYPE_UNKNOWN) { - throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str())); - } - } - - // other hparams - { - get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false); - get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false); - GGML_ASSERT(ctx_clip.has_vision_encoder); - GGML_ASSERT(!ctx_clip.has_text_encoder); - - // legacy keys, use KEY_PROJ_TYPE instead - get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false); - get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false); - get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); - get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false); - get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false); - // !!! do NOT extend the list above, use KEY_PROJ_TYPE instead - - get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false); - get_bool(KEY_USE_SILU, ctx_clip.use_silu, false); - - auto & hparams = ctx_clip.vision_model.hparams; - get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size); - get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head); - get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate); - get_u32(string_format(KEY_N_BLOCK, "vision"), hparams.n_layer); - get_u32(string_format(KEY_PROJ_DIM, "vision"), hparams.projection_dim); - get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps); - get_u32(KEY_IMAGE_SIZE, hparams.image_size); - get_u32(KEY_PATCH_SIZE, hparams.patch_size); - get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); - get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false); - - { - std::string mm_patch_merge_type; - get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false); - if (mm_patch_merge_type == "spatial_unpad") { - hparams.mm_patch_merge_type = PATCH_MERGE_SPATIAL_UNPAD; - } - } - - { - int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN); - int idx_std = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD); - GGML_ASSERT(idx_mean >= 0 && "image_mean not found"); - GGML_ASSERT(idx_std >= 0 && "image_std not found"); - const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean); - const float * std_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std); - for (int i = 0; i < 3; ++i) { - ctx_clip.image_mean[i] = mean_data[i]; - ctx_clip.image_std[i] = std_data[i]; - } - } - - // Load the vision feature layer indices if they are explicitly provided; - // if multiple vision feature layers are present, the values will be concatenated - // to form the final visual features. - // NOTE: gguf conversions should standardize the values of the vision feature layer to - // be non-negative, since we use -1 to mark values as unset here. - std::vector vision_feature_layer; - get_arr_int(KEY_FEATURE_LAYER, vision_feature_layer, false); - // convert std::vector to std::unordered_set - for (auto & layer : vision_feature_layer) { - hparams.vision_feature_layer.insert(layer); - } - // Calculate the deepest feature layer based on hparams and projector type - ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip); - - LOG_INF("%s: text_encoder: %d\n", __func__, ctx_clip.has_text_encoder); - LOG_INF("%s: vision_encoder: %d\n", __func__, ctx_clip.has_vision_encoder); - LOG_INF("%s: llava_projector: %d\n", __func__, ctx_clip.has_llava_projector); - LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector); - LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version); - LOG_INF("%s: glm_projector: %d\n", __func__, ctx_clip.has_glm_projector); - LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); - LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0); - } - } - - void load_tensors() { - std::map tensor_offset; - std::vector tensors_to_load; - - // get offsets - for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) { - const char * name = gguf_get_tensor_name(ctx_gguf.get(), i); - tensor_offset[name] = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), i); - } - - // create data context - struct ggml_init_params params = { - /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - ctx_clip.ctx_data.reset(ggml_init(params)); - if (!ctx_clip.ctx_data) { - throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__)); - } - - // helper function - auto get_tensor = [&](const std::string & name, bool required = true) { - struct ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str()); - if (!cur && required) { - throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str())); - } - if (cur) { - tensors_to_load.push_back(cur); - // add tensors to context - struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur); - ggml_set_name(data_tensor, cur->name); - cur = data_tensor; - } - return cur; - }; - - auto & vision_model = ctx_clip.vision_model; - - vision_model.class_embedding = get_tensor(TN_CLASS_EMBD, false); - - vision_model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, "v", "weight"), false); - vision_model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, "v", "bias"), false); - - vision_model.post_ln_w = get_tensor(string_format(TN_LN_POST, "v", "weight"), false); - vision_model.post_ln_b = get_tensor(string_format(TN_LN_POST, "v", "bias"), false); - - vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false); - vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false); - vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false); - if (vision_model.patch_embeddings_1 == nullptr) { - ctx_clip.has_qwen2vl_merger = false; - } - - vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false); - - // layers - vision_model.layers.resize(vision_model.hparams.n_layer); - for (int il = 0; il < vision_model.hparams.n_layer; ++il) { - auto & layer = vision_model.layers[il]; - layer.k_w = get_tensor(string_format(TN_ATTN_K, "v", il, "weight")); - layer.q_w = get_tensor(string_format(TN_ATTN_Q, "v", il, "weight")); - layer.v_w = get_tensor(string_format(TN_ATTN_V, "v", il, "weight")); - layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight")); - layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false); - layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); - layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight")); - layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); - layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); - layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); - layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); - layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false); - layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); - layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); - layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false); - layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); - } - - switch (ctx_clip.proj_type) { - case PROJECTOR_TYPE_MLP: - case PROJECTOR_TYPE_MLP_NORM: - { - // LLaVA projection - vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false); - vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false); - // Yi-type llava - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false); - // missing in Yi-type llava - vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false); - vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false); - // Yi-type llava - vision_model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false); - vision_model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false); - vision_model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false); - vision_model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false); - if (vision_model.mm_3_w) { - // TODO: this is a hack to support Yi-type llava - ctx_clip.proj_type = PROJECTOR_TYPE_MLP_NORM; - } - vision_model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false); - } break; - case PROJECTOR_TYPE_LDP: - { - // MobileVLM projection - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); - vision_model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); - vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); - vision_model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); - vision_model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight")); - vision_model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight")); - vision_model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias")); - vision_model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight")); - vision_model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias")); - vision_model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight")); - vision_model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias")); - vision_model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight")); - vision_model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight")); - vision_model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias")); - vision_model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight")); - vision_model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight")); - vision_model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias")); - vision_model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight")); - vision_model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias")); - vision_model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight")); - vision_model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias")); - vision_model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight")); - vision_model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight")); - vision_model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias")); - } break; - case PROJECTOR_TYPE_LDPV2: - { - // MobilVLM_V2 projection - vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); - vision_model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight")); - vision_model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias")); - vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight")); - vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias")); - } break; - case PROJECTOR_TYPE_RESAMPLER: - { - // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD); - vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K); - vision_model.mm_model_query = get_tensor(TN_MINICPMV_QUERY); - vision_model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ); - vision_model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ); - vision_model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight")); - vision_model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight")); - vision_model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight")); - vision_model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias")); - vision_model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias")); - vision_model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias")); - vision_model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight")); - vision_model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias")); - vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight")); - vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias")); - vision_model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight")); - vision_model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias")); - vision_model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight")); - vision_model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias")); - } break; - case PROJECTOR_TYPE_GLM_EDGE: - { - vision_model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight")); - vision_model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias")); - vision_model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR,"weight")); - vision_model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"weight")); - vision_model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1,"bias")); - vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); - vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight")); - vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); - vision_model.boi_w = get_tensor(TN_GLM_BOI_W); - vision_model.eoi_w = get_tensor(TN_GLM_EOI_W); - } break; - case PROJECTOR_TYPE_MERGER: - { - vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); - vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); - vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); - vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); - } break; - case PROJECTOR_TYPE_GEMMA3: - { - vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ); - vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N); - } break; - default: - GGML_ASSERT(false && "unknown projector type"); - } - - // load data - { - std::vector read_buf; - - auto fin = std::ifstream(fname, std::ios::binary); - if (!fin) { - throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); - } - - // alloc memory and offload data - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); - ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); - ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - for (auto & t : tensors_to_load) { - struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); - const size_t offset = tensor_offset[t->name]; - fin.seekg(offset, std::ios::beg); - if (!fin) { - throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name)); - } - size_t num_bytes = ggml_nbytes(cur); - if (ggml_backend_buft_is_host(buft)) { - // for the CPU and Metal backend, we can read directly into the tensor - fin.read(reinterpret_cast(cur->data), num_bytes); - } else { - // read into a temporary buffer first, then copy to device memory - read_buf.resize(num_bytes); - fin.read(reinterpret_cast(read_buf.data()), num_bytes); - ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); - } - } - fin.close(); - - LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str()); - } - } - - void alloc_compute_meta() { - ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead()); - - // create a fake batch - clip_image_f32_batch batch; - clip_image_f32_ptr img(clip_image_f32_init()); - clip_image_size image_size; - image_size.width = clip_get_image_size(&ctx_clip); - image_size.height = clip_get_image_size(&ctx_clip); - int n_patches = clip_get_image_size(&ctx_clip) / image_size.width; - img->nx = n_patches; - img->ny = n_patches; - img->buf.resize(n_patches * image_size.width * image_size.height * 3); - batch.entries.push_back(std::move(img)); - - ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false); - ggml_backend_sched_reserve(ctx_clip.sched.get(), gf); - for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) { - ggml_backend_t backend = ctx_clip.backend_ptrs[i]; - ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i]; - size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend); - if (size > 1) { - LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - size / 1024.0 / 1024.0); - } - } - } - - void get_bool(const std::string & key, bool & output, bool required = true) { - const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); - if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); - return; - } - output = gguf_get_val_bool(ctx_gguf.get(), i); - } - - void get_i32(const std::string & key, int & output, bool required = true) { - const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); - if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); - return; - } - output = gguf_get_val_i32(ctx_gguf.get(), i); - } - - void get_u32(const std::string & key, int & output, bool required = true) { - const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); - if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); - return; - } - output = gguf_get_val_u32(ctx_gguf.get(), i); - } - - void get_f32(const std::string & key, float & output, bool required = true) { - const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); - if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); - return; - } - output = gguf_get_val_f32(ctx_gguf.get(), i); - } - - void get_string(const std::string & key, std::string & output, bool required = true) { - const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); - if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); - return; - } - output = std::string(gguf_get_val_str(ctx_gguf.get(), i)); - } - - void get_arr_int(const std::string & key, std::vector & output, bool required = true) { - const int i = gguf_find_key(ctx_gguf.get(), key.c_str()); - if (i < 0) { - if (required) throw std::runtime_error("Key not found: " + key); - return; - } - int n = gguf_get_arr_n(ctx_gguf.get(), i); - output.resize(n); - const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i); - for (int i = 0; i < n; ++i) { - output[i] = values[i]; - } - } -}; - -// read and create ggml_context containing the tensors and their data -struct clip_ctx * clip_model_load(const char * fname, const int verbosity) { - return clip_init(fname, clip_context_params{ - /* use_gpu */ true, - /* verbosity */ static_cast(verbosity), - }); -} - -struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params) { - g_logger_state.verbosity_thold = ctx_params.verbosity; - clip_ctx * ctx_clip = new clip_ctx(ctx_params); - - try { - clip_model_loader loader(fname, *ctx_clip); - loader.load_hparams(); - loader.load_tensors(); - loader.alloc_compute_meta(); - } catch (const std::exception & e) { - LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what()); - delete ctx_clip; - return nullptr; - } - - return ctx_clip; -} - -void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) { - ctx_clip->load_image_size = *load_image_size; // copy -} - -struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) { - return &ctx_clip->load_image_size; -} - -struct clip_image_size * clip_image_size_init() { - struct clip_image_size * load_image_size = new struct clip_image_size(); - load_image_size->width = 448; - load_image_size->height = 448; - return load_image_size; -} - -struct clip_image_u8 * clip_image_u8_init() { - return new clip_image_u8(); -} - -struct clip_image_f32 * clip_image_f32_init() { - return new clip_image_f32(); -} - -struct clip_image_f32_batch * clip_image_f32_batch_init() { - return new clip_image_f32_batch(); -} - -unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) { - if (nx) *nx = img->nx; - if (ny) *ny = img->ny; - return img->buf.data(); -} - -void clip_image_size_free(struct clip_image_size * load_image_size) { - if (load_image_size == nullptr) { - return; - } - delete load_image_size; -} -void clip_image_u8_free(struct clip_image_u8 * img) { if (img) delete img; } -void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; } -void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; } -void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; } - -size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) { - return batch->entries.size(); -} - -size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) { - if (idx < 0 || idx >= (int)batch->entries.size()) { - LOG_ERR("%s: invalid index %d\n", __func__, idx); - return 0; - } - return batch->entries[idx]->nx; -} - -size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) { - if (idx < 0 || idx >= (int)batch->entries.size()) { - LOG_ERR("%s: invalid index %d\n", __func__, idx); - return 0; - } - return batch->entries[idx]->ny; -} - -clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) { - if (idx < 0 || idx >= (int)batch->entries.size()) { - LOG_ERR("%s: invalid index %d\n", __func__, idx); - return nullptr; - } - return batch->entries[idx].get(); -} - -void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) { - img->nx = nx; - img->ny = ny; - img->buf.resize(3 * nx * ny); - memcpy(img->buf.data(), rgb_pixels, img->buf.size()); -} - -bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { - int nx, ny, nc; - auto * data = stbi_load(fname, &nx, &ny, &nc, 3); - if (!data) { - LOG_ERR("%s: failed to load image '%s'\n", __func__, fname); - return false; - } - clip_build_img_from_pixels(data, nx, ny, img); - stbi_image_free(data); - return true; -} - -bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) { - int nx, ny, nc; - auto * data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3); - if (!data) { - LOG_ERR("%s: failed to decode image bytes\n", __func__); - return false; - } - clip_build_img_from_pixels(data, nx, ny, img); - stbi_image_free(data); - return true; -} - -// Linear interpolation between two points -inline float clip_lerp(float s, float e, float t) { - return s + (e - s) * t; -} -// Bilinear resize function -static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) { - dst.nx = target_width; - dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); - - float x_ratio = static_cast(src.nx - 1) / target_width; - float y_ratio = static_cast(src.ny - 1) / target_height; - - for (int y = 0; y < target_height; y++) { - for (int x = 0; x < target_width; x++) { - float px = x_ratio * x; - float py = y_ratio * y; - int x_floor = static_cast(px); - int y_floor = static_cast(py); - float x_lerp = px - x_floor; - float y_lerp = py - y_floor; - - for (int c = 0; c < 3; c++) { - float top = clip_lerp( - static_cast(src.buf[3 * (y_floor * src.nx + x_floor) + c]), - static_cast(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]), - x_lerp - ); - float bottom = clip_lerp( - static_cast(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]), - static_cast(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]), - x_lerp - ); - dst.buf[3 * (y * target_width + x) + c] = static_cast(clip_lerp(top, bottom, y_lerp)); - } - } - } -} - -// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not -static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) { - dst.nx = src.nx; - dst.ny = src.ny; - dst.buf.resize(src.buf.size()); - - // TODO @ngxson : seems like this could be done more efficiently on cgraph - for (size_t i = 0; i < src.buf.size(); ++i) { - int c = i % 3; // rgb - dst.buf[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c]; - } -} - -inline int clip(int x, int lower, int upper) { - return std::max(lower, std::min(x, upper)); -} - -static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { - const int nx = img.nx; - const int ny = img.ny; - - dst.nx = target_width; - dst.ny = target_height; - dst.buf.resize(3 * target_width * target_height); - - float Cc; - float C[5]; - float d0, d2, d3, a0, a1, a2, a3; - int i, j, k, jj; - int x, y; - float dx, dy; - float tx, ty; - - tx = (float)nx / (float)target_width; - ty = (float)ny / (float)target_height; - - // Bicubic interpolation; adapted from ViT.cpp, inspired from : - // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 - // -> https://en.wikipedia.org/wiki/Bicubic_interpolation - - for (i = 0; i < target_height; i++) { - for (j = 0; j < target_width; j++) { - x = (int)(tx * j); - y = (int)(ty * i); - - dx = tx * j - x; - dy = ty * i - y; - - for (k = 0; k < 3; k++) { - for (jj = 0; jj <= 3; jj++) { - d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; - - a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; - - C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; - - d0 = C[0] - C[1]; - d2 = C[2] - C[1]; - d3 = C[3] - C[1]; - a0 = C[1]; - a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; - a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; - a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; - Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; - - const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); - dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); - } - } - } - } - - return true; -} - -// llava-1.6 type of resize_and_pad (black) -static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair& target_resolution) { - int target_width = target_resolution.first; - int target_height = target_resolution.second; - - float scale_w = static_cast(target_width) / image.nx; - float scale_h = static_cast(target_height) / image.ny; - - int new_width, new_height; - - if (scale_w < scale_h) { - new_width = target_width; - new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height); - } else { - new_height = target_height; - new_width = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width); - } - - clip_image_u8 resized_image; - // bilinear_resize(image, resized_image, new_width, new_height); - bicubic_resize(image, resized_image, new_width, new_height); - - clip_image_u8 padded_image; - padded_image.nx = target_width; - padded_image.ny = target_height; - padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black - - // Calculate padding offsets - int pad_x = (target_width - new_width) / 2; - int pad_y = (target_height - new_height) / 2; - - // Copy the resized image into the center of the padded buffer - for (int y = 0; y < new_height; ++y) { - for (int x = 0; x < new_width; ++x) { - for (int c = 0; c < 3; ++c) { - padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; - } - } - } - image_output = std::move(padded_image); -} - -/** - * Selects the best resolution from a list of possible resolutions based on the original size. - * - * @param original_size The original size of the image in the format (width, height). - * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. - * @return The best fit resolution in the format (width, height). - */ -static std::pair select_best_resolution(const std::pair & original_size, const std::vector> & possible_resolutions) { - int original_width = original_size.first; - int original_height = original_size.second; - std::pair best_fit; - int max_effective_resolution = 0; - int min_wasted_resolution = std::numeric_limits::max(); - - for (const auto& resolution : possible_resolutions) { - int width = resolution.first; - int height = resolution.second; - float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); - int downscaled_height = static_cast(original_height * scale); - int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); - int wasted_resolution = (width * height) - effective_resolution; - // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); - if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { - max_effective_resolution = effective_resolution; - min_wasted_resolution = wasted_resolution; - best_fit = resolution; - } - } - - return best_fit; -} - -static std::vector divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { - std::vector patches; - int width = image.nx; - int height = image.ny; - for (int i = 0; i < height; i += patch_size) { - for (int j = 0; j < width; j += patch_size) { - clip_image_u8_ptr patch(clip_image_u8_init()); - patch->nx = std::min(patch_size, width - j); - patch->ny = std::min(patch_size, height - i); - patch->buf.resize(3 * patch->nx * patch->ny); - for (int y = 0; y < patch->ny; ++y) { - for (int x = 0; x < patch->nx; ++x) { - for (int c = 0; c < 3; ++c) { - patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c]; - } - } - } - patches.push_back(std::move(patch)); - } - } - return patches; -} - -static int ensure_divide(int length, int patch_size) { - return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size); -} - -static std::pair uhd_find_best_resize(std::pair original_size, int scale_resolution, int patch_size, bool allow_upscale = false) { - int width = original_size.first; - int height = original_size.second; - if ((width * height > scale_resolution * scale_resolution) || allow_upscale) { - float r = static_cast(width) / height; - height = static_cast(scale_resolution / std::sqrt(r)); - width = static_cast(height * r); - } - int best_width = ensure_divide(width, patch_size); - int best_height = ensure_divide(height, patch_size); - return std::make_pair(best_width, best_height); -} - -static std::pair uhd_get_refine_size(std::pair original_size, std::pair grid, int scale_resolution, int patch_size, bool allow_upscale = false) { - int width, height; - std::tie(width, height) = original_size; - int grid_x, grid_y; - std::tie(grid_x, grid_y) = grid; - - int refine_width = ensure_divide(width, grid_x); - int refine_height = ensure_divide(height, grid_y); - - int grid_width = refine_width / grid_x; - int grid_height = refine_height / grid_y; - - // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line) - auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair - int best_grid_width, best_grid_height; - std::tie(best_grid_width, best_grid_height) = best_grid_size; - - // std::pair refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line) - std::pair refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line) - return refine_size; -} - -static std::pair uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) { - std::vector candidate_split_grids_nums; - for (int i : {multiple - 1, multiple, multiple + 1}) { - if (i == 1 || i > max_slice_nums) { - continue; - } - candidate_split_grids_nums.push_back(i); - } - - std::vector> candidate_grids; - for (int split_grids_nums : candidate_split_grids_nums) { - int m = 1; - while (m <= split_grids_nums) { - if (split_grids_nums % m == 0) { - candidate_grids.emplace_back(m, split_grids_nums / m); - } - ++m; - } - } - - std::pair best_grid{1, 1}; - float min_error = std::numeric_limits::infinity(); - for (const auto& grid : candidate_grids) { - float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second)); - if (error < min_error) { - best_grid = grid; - min_error = error; - } - } - return best_grid; -} - -// inspired from LLaVA-UHD: -// -> https://arxiv.org/pdf/2403.11703 -// -> https://github.com/thunlp/LLaVA-UHD -// -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118 -static std::vector> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) { - const std::pair original_size={img->nx,img->ny}; - const int original_width = img->nx; - const int original_height = img->ny; - const float log_ratio = log(1.0*original_width/original_height); - const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); - const int multiple = fmin(ceil(ratio), max_slice_nums); - - std::vector> images; - LOG_DBG("%s: multiple %d\n", __func__, multiple); - images.push_back(std::vector()); - - if (multiple <= 1) { - auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true); - clip_image_u8_ptr source_image(clip_image_u8_init()); - bicubic_resize(*img, *source_image, best_size.first, best_size.second); - // source_image = image.resize(best_size, Image.Resampling.BICUBIC) - images.back().push_back(std::move(source_image)); - } - else if (multiple > 1) { - auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size); - clip_image_u8_ptr source_image(clip_image_u8_init()); - bicubic_resize(*img, *source_image, best_size.first, best_size.second); - // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC) - LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second); - images.back().push_back(std::move(source_image)); - - std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); - LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second); - - auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true); - clip_image_u8_ptr refine_image(clip_image_u8_init()); - bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second); - - LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second); - - // split_to_patches - int width = refine_image->nx; - int height = refine_image->ny; - int grid_x = int(width / best_grid.first); - int grid_y = int(height / best_grid.second); - for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){ - images.push_back(std::vector()); - for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){ - clip_image_u8_ptr patch(clip_image_u8_init()); - patch->nx = grid_x; - patch->ny = grid_y; - patch->buf.resize(3 * patch->nx * patch->ny); - for (int y = patches_i; y < patches_i + grid_y; ++y) { - for (int x = patches_j; x < patches_j + grid_x; ++x) { - const int i = 3 * (y * refine_image->nx + x); - const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j)); - patch->buf[j] = refine_image->buf[i]; - patch->buf[j+1] = refine_image->buf[i+1]; - patch->buf[j+2] = refine_image->buf[i+2]; - } - } - images.back().push_back(std::move(patch)); - } - } - } - return images; -} - -int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) { - const int max_slice_nums=9; - const int scale_resolution=448; - const int original_width = ctx_clip->load_image_size.width; - const int original_height = ctx_clip->load_image_size.height; - const float log_ratio = log(1.0*original_width/original_height); - const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution); - const int multiple = fmin(ceil(ratio), max_slice_nums); - std::pair best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio); - return best_grid.first; -} - -// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector -// res_imgs memory is being allocated here, previous allocations will be freed if found -bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { - - if (clip_is_minicpmv(ctx)) { - int max_slice_nums = 9; - std::vector> imgs = uhd_slice_image(img, max_slice_nums); - for (size_t i = 0; i < imgs.size(); ++i) { - for (size_t j = 0; j < imgs[i].size(); ++j) { - LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny); - clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std); - res_imgs->entries.push_back(std::move(res)); - } - } - return true; - } - else if (ctx->has_qwen2vl_merger) { - clip_image_u8 resized; - auto patch_size = clip_get_patch_size(ctx) * 2; - int nx = ceil((float)img->nx / patch_size) * patch_size; - int ny = ceil((float)img->ny / patch_size) * patch_size; - bicubic_resize(*img, resized, nx, ny); - - clip_image_f32_ptr img_f32(clip_image_f32_init()); - // clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(resized, *img_f32, ctx->image_mean, ctx->image_std); - // res_imgs->data[0] = *res; - res_imgs->entries.push_back(std::move(img_f32)); - return true; - } - - if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - clip_image_u8 resized_image; - int32_t sz=ctx->vision_model.hparams.image_size; - bicubic_resize(*img, resized_image,sz,sz); - clip_image_f32_ptr img_f32(clip_image_f32_init()); - //clip_image_save_to_bmp(resized_image, "resized.bmp"); - normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std); - res_imgs->entries.push_back(std::move(img_f32)); - return true; - } - - bool pad_to_square = true; - if (!ctx->has_vision_encoder) { - LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); - return false; - } - auto & params = ctx->vision_model.hparams; - // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing - if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) { - pad_to_square = false; - } - // free the previous res_imgs if any set - res_imgs->entries.clear(); - - // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) - // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 - - clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily - if (pad_to_square && img->nx != img->ny) { - int longer_side = std::max(img->nx, img->ny); - temp->nx = longer_side; - temp->ny = longer_side; - temp->buf.resize(3 * longer_side * longer_side); - const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) - - // fill with background color - for (size_t i = 0; i < temp->buf.size(); i++) { - temp->buf[i] = bc[i % 3]; - } - - // copy from the input image - for (int y = 0; y < img->ny; y++) { - for (int x = 0; x < img->nx; x++) { - const int i = 3 * (y * img->nx + x); - const int j = 3 * (y * temp->nx + x); - temp->buf[j] = img->buf[i]; - temp->buf[j+1] = img->buf[i+1]; - temp->buf[j+2] = img->buf[i+2]; - } - } - } else { - if (!params.image_grid_pinpoints.empty()) { - // "spatial_unpad" with "anyres" processing for llava-1.6 - std::vector> possible_resolutions; - for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) { - possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); - } - std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); - // clip_image_save_to_bmp(*img, "input.bmp"); - resize_and_pad_image(*img, *temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 - // clip_image_save_to_bmp(*temp, "resized.bmp"); - // visually verify normalized image: - // normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std); - // { - // clip_image_u8 * temp2 = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*res, *temp2); - // clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp"); - // clip_image_u8_free(temp2); - // } - - std::vector patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) - - clip_image_u8_ptr image_original_resize(clip_image_u8_init()); - // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square - patches.insert(patches.begin(), std::move(image_original_resize)); - for (auto & patch : patches) { - clip_image_f32_ptr res(clip_image_f32_init()); - normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std); - res_imgs->entries.push_back(std::move(res)); - } - - return true; - } else { - temp->nx = img->nx; - temp->ny = img->ny; - temp->buf.resize(img->buf.size()); - memcpy(temp->buf.data(), img->buf.data(), temp->buf.size()); - } - } - - const int nx = temp->nx; - const int ny = temp->ny; - // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp"); - - const int nx2 = ctx->vision_model.hparams.image_size; - const int ny2 = ctx->vision_model.hparams.image_size; - clip_image_f32_ptr res(clip_image_f32_init()); - res->nx = nx2; - res->ny = ny2; - res->buf.resize(3 * nx2 * ny2); - - const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size; - - const int nx3 = int(nx / scale + 0.5f); - const int ny3 = int(ny / scale + 0.5f); - - const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; - const auto & s3 = ctx->image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; - - for (int y = 0; y < ny3; y++) { - for (int x = 0; x < nx3; x++) { - for (int c = 0; c < 3; c++) { - // linear interpolation - const float sx = (x + 0.5f) * scale - 0.5f; - const float sy = (y + 0.5f) * scale - 0.5f; - - const int x0 = std::max(0, (int)std::floor(sx)); - const int y0 = std::max(0, (int)std::floor(sy)); - - const int x1 = std::min(x0 + 1, nx - 1); - const int y1 = std::min(y0 + 1, ny - 1); - - const float dx = sx - x0; - const float dy = sy - y0; - - const int j00 = 3 * (y0 * nx + x0) + c; - const int j01 = 3 * (y0 * nx + x1) + c; - const int j10 = 3 * (y1 * nx + x0) + c; - const int j11 = 3 * (y1 * nx + x1) + c; - - const float v00 = temp->buf[j00]; - const float v01 = temp->buf[j01]; - const float v10 = temp->buf[j10]; - const float v11 = temp->buf[j11]; - - const float v0 = v00 * (1.0f - dx) + v01 * dx; - const float v1 = v10 * (1.0f - dx) + v11 * dx; - - const float v = v0 * (1.0f - dy) + v1 * dy; - - const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); - - const int i = 3 * (y * nx3 + x) + c; - - res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; - } - } - } - - // { - // clip_image_u8 * temp2 = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*res, *temp2); - // clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp"); - // clip_image_u8_free(temp2); - // } - // res_imgs.push_back(res); - - res_imgs->entries.push_back(std::move(res)); - - return true; -} - -ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) { - return ctx->vision_model.image_newline; -} - -void clip_free(clip_ctx * ctx) { - if (ctx == nullptr) { - return; - } - delete ctx; -} - -size_t clip_embd_nbytes(const struct clip_ctx * ctx) { - int extra_tokens = ctx->has_glm_projector ? 2 : 0; - return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float); -} - -size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) { - clip_image_f32 img; - img.nx = img_w; - img.ny = img_h; - return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float); -} - -int32_t clip_get_image_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_size; -} - -int32_t clip_get_patch_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.patch_size; -} - -int32_t clip_get_hidden_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.hidden_size; -} - -const char * clip_patch_merge_type(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat"; -} - -const int32_t * clip_image_grid(const struct clip_ctx * ctx) { - if (ctx->vision_model.hparams.image_grid_pinpoints.size()) { - return &ctx->vision_model.hparams.image_grid_pinpoints.front(); - } - return nullptr; -} - -size_t get_clip_image_grid_size(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_grid_pinpoints.size(); -} - -int clip_n_patches(const struct clip_ctx * ctx) { - clip_image_f32 img; - img.nx = ctx->vision_model.hparams.image_size; - img.ny = ctx->vision_model.hparams.image_size; - return clip_n_patches_by_img(ctx, &img); -} - -int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) { - const auto & params = ctx->vision_model.hparams; - - int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - - if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { - n_patches /= 4; - } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - if (ctx->minicpmv_version == 2) { - n_patches = 96; - } - else if (ctx->minicpmv_version == 3) { - n_patches = 64; - } - else if (ctx->minicpmv_version == 4) { - n_patches = 64; - } - } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { - int patch_size = params.patch_size * 2; - int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); - int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0); - n_patches = x_patch * y_patch; - } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - n_patches = 256; - } - - return n_patches; -} - -static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) { - assert(embed_dim % 2 == 0); - int H = pos.size(); - int W = pos[0].size(); - - std::vector omega(embed_dim / 2); - for (int i = 0; i < embed_dim / 2; ++i) { - omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2)); - } - - std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - for (int d = 0; d < embed_dim / 2; ++d) { - float out_value = pos[h][w] * omega[d]; - emb[h][w][d] = sin(out_value); - emb[h][w][d + embed_dim / 2] = cos(out_value); - } - } - } - - return emb; -} - -static std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>> & grid) { - assert(embed_dim % 2 == 0); - std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2) - std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2) - - int H = emb_h.size(); - int W = emb_h[0].size(); - std::vector>> emb(H, std::vector>(W, std::vector(embed_dim))); - - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - for (int d = 0; d < embed_dim / 2; ++d) { - emb[h][w][d] = emb_h[h][w][d]; - emb[h][w][d + embed_dim / 2] = emb_w[h][w][d]; - } - } - } - return emb; -} - -static std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) { - int grid_h_size = image_size.first; - int grid_w_size = image_size.second; - - std::vector grid_h(grid_h_size); - std::vector grid_w(grid_w_size); - - for (int i = 0; i < grid_h_size; ++i) { - grid_h[i] = static_cast(i); - } - for (int i = 0; i < grid_w_size; ++i) { - grid_w[i] = static_cast(i); - } - - std::vector> grid(grid_h_size, std::vector(grid_w_size)); - for (int h = 0; h < grid_h_size; ++h) { - for (int w = 0; w < grid_w_size; ++w) { - grid[h][w] = grid_w[w]; - } - } - std::vector>> grid_2d = {grid, grid}; - for (int h = 0; h < grid_h_size; ++h) { - for (int w = 0; w < grid_w_size; ++w) { - grid_2d[0][h][w] = grid_h[h]; - grid_2d[1][h][w] = grid_w[w]; - } - } - - std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d); - - int H = image_size.first; - int W = image_size.second; - std::vector> pos_embed_2d(H * W, std::vector(embed_dim)); - for (int h = 0; h < H; ++h) { - for (int w = 0; w < W; ++w) { - pos_embed_2d[w * H + h] = pos_embed_3d[h][w]; - } - } - - return pos_embed_2d; -} - -bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) { - if (!ctx->has_vision_encoder) { - LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); - return false; - } - - clip_image_f32_batch imgs; - clip_image_f32_ptr img_copy(clip_image_f32_init()); - *img_copy = *img; - imgs.entries.push_back(std::move(img_copy)); - - return clip_image_batch_encode(ctx, n_threads, &imgs, vec); -} - -bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) { - const clip_image_f32_batch & imgs = *imgs_c_ptr; - - if (!ctx->has_vision_encoder) { - LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__); - return false; - } - - int batch_size = imgs.entries.size(); - if (ctx->has_llava_projector) { - GGML_ASSERT(batch_size == 1); // TODO: support multiple images - } - if (ctx->has_minicpmv_projector) { - GGML_ASSERT(batch_size == 1); - } - if (ctx->has_glm_projector) { - GGML_ASSERT(batch_size == 1); - ggml_tensor * boi = ctx->vision_model.boi_w; - ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi)); - vec = (float*)(vec+ggml_nelements(boi)); //offset for boi - } - - // build the inference graph - ggml_backend_sched_reset(ctx->sched.get()); - ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); - ggml_backend_sched_alloc_graph(ctx->sched.get(), gf); - - // set inputs - const auto & model = ctx->vision_model; - const auto & hparams = model.hparams; - - const int image_size = hparams.image_size; - int image_size_width = image_size; - int image_size_height = image_size; - if (ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger) { - image_size_width = imgs.entries[0]->nx; - image_size_height = imgs.entries[0]->ny; - } - const int patch_size = hparams.patch_size; - const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size)); - const int num_positions = num_patches + (model.class_embedding ? 1 : 0); - const int pos_w = ctx->load_image_size.width / patch_size; - const int pos_h = ctx->load_image_size.height / patch_size; - - { - struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw"); - float * data = (float *)malloc(ggml_nbytes(inp_raw)); - - for (size_t i = 0; i < imgs.entries.size(); i++) { - const int nx = imgs.entries[i]->nx; - const int ny = imgs.entries[i]->ny; - if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) { - GGML_ASSERT(nx == image_size && ny == image_size); - } - - const int n = nx * ny; - - for (int b = 0; b < batch_size; b++) { - for (int k = 0; k < 3; k++) { - for (int y = 0; y < ny; y++) { - for (int x = 0; x < nx; x++) { - data[(b * 3 * n) + k * n + y * nx + x] = imgs.entries[b]->buf[3 * (y * nx + x) + k]; - } - } - } - } - } - ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw)); - free(data); - } - if (ctx->has_minicpmv_projector) { - { - // inspired from siglip: - // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit - // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int bucket_coords_h[1024]; - int bucket_coords_w[1024]; - for (int i = 0; i < pos_h; i++){ - bucket_coords_h[i] = std::floor(70.0*i/pos_h); - } - for (int i = 0; i < pos_w; i++){ - bucket_coords_w[i] = std::floor(70.0*i/pos_w); - } - for (int i = 0, id = 0; i < pos_h; i++){ - for (int j = 0; j < pos_w; j++){ - positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j]; - } - } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); - } - - { - // inspired from resampler of Qwen-VL: - // -> https://huggingface.co/Qwen/Qwen-VL/tree/main - // -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23 - struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed"); - int embed_dim = 4096; - if (ctx->minicpmv_version == 2) { - embed_dim = 4096; - } - else if (ctx->minicpmv_version == 3) { - embed_dim = 3584; - } - else if (ctx->minicpmv_version == 4) { - embed_dim = 3584; - } - auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); - - float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); - for(int i=0;i < pos_w * pos_h; ++i){ - for(int j=0; j < embed_dim; ++j){ - pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j]; - } - } - - ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed)); - free(pos_embed_data); - } - } - else { - if (model.class_embedding) { - struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings"); - - void* zero_mem = malloc(ggml_nbytes(embeddings)); - memset(zero_mem, 0, ggml_nbytes(embeddings)); - ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings)); - free(zero_mem); - } - - if (ctx->has_qwen2vl_merger) { - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - - const int pw = image_size_width / patch_size; - const int ph = image_size_height / patch_size; - int* positions_data = (int*)malloc(ggml_nbytes(positions)); - - int ptr = 0; - for (int y = 0; y < ph; y+=2) - { - for (int x = 0; x < pw; x+=2) - { - for (int dy = 0; dy < 2; dy++) { - for (int dx = 0; dx < 2; dx++) { - positions_data[ptr] = y + dy; - positions_data[num_patches + ptr] = x + dx; - positions_data[num_patches * 2 + ptr] = y + dy; - positions_data[num_patches * 3 + ptr] = x + dx; - ptr++; - } - } - } - } - - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); - } - else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - // do nothing - } - else { - struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); - - int* positions_data = (int*)malloc(ggml_nbytes(positions)); - for (int i = 0; i < num_positions; i++) { - positions_data[i] = i; - } - ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); - free(positions_data); - - if (!ctx->has_glm_projector) { - struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); - // The patches vector is used to get rows to index into the embeds with; - // we should skip dim 0 only if we have CLS to avoid going out of bounds - // when retrieving the rows. - int patch_offset = model.class_embedding ? 1 : 0; - int* patches_data = (int*)malloc(ggml_nbytes(patches)); - for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + patch_offset; - } - ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); - free(patches_data); - } - } - } - - ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads); - - auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf); - if (status != GGML_STATUS_SUCCESS) { - LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status); - return false; - } - - // the last node is the embedding tensor - struct ggml_tensor * embeddings = ggml_graph_node(gf, -1); - - // copy the embeddings to the location passed by the user - ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); - - if (ctx->has_glm_projector) { - //eoi - ggml_tensor * eoi = ctx->vision_model.eoi_w; - int offset = ggml_nelements(embeddings); - ggml_backend_tensor_get(eoi, vec+offset, 0, ggml_nbytes(eoi)); - } - - return true; -} - -bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { - assert(itype < GGML_TYPE_COUNT); - ggml_type type = static_cast(itype); - - auto * ctx_clip = clip_init(fname_inp, clip_context_params{ - /* use_gpu */ false, - /* verbosity */ GGML_LOG_LEVEL_ERROR, - }); - - const auto & ctx_src = ctx_clip->ctx_gguf.get(); - const auto & ctx_data = ctx_clip->ctx_data.get(); - - auto * ctx_out = gguf_init_empty(); - gguf_set_kv(ctx_out, ctx_src); - gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); - gguf_set_val_u32(ctx_out, "general.file_type", itype); - - auto fout = std::ofstream(fname_out, std::ios::binary); - - const int n_tensors = gguf_get_n_tensors(ctx_src); - - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx_src, i); - struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); - gguf_add_tensor(ctx_out, cur); - } - - const size_t meta_size = gguf_get_meta_size(ctx_out); - for (size_t i = 0; i < meta_size; ++i) { - fout.put(0); - } - - // regexes of tensor names to be quantized - const std::vector k_names = { - ".*weight", - }; - - std::vector work(512); - std::vector conv_buf(512); - size_t total_size_org = 0; - size_t total_size_new = 0; - - for (int i = 0; i < n_tensors; ++i) { - const std::string name = gguf_get_tensor_name(ctx_src, i); - struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str()); - - enum ggml_type new_type; - void * new_data; - size_t new_size; - - bool quantize = false; - for (const auto & s : k_names) { - if (std::regex_match(name, std::regex(s))) { - quantize = true; - break; - } - } - - // quantize only 2D tensors and bigger than block size - quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type); - - if (quantize) { - new_type = type; - if (new_type >= GGML_TYPE_Q2_K && name.find("embd") != std::string::npos) { - new_type = GGML_TYPE_Q8_0; // ggml_get_rows needs non K type - // LOG_ERR("%s: quantizing %s to %s\n", __func__, name.c_str(), ggml_type_name(new_type)); - } - const size_t n_elms = ggml_nelements(cur); - float * f32_data; - - switch (cur->type) { - case GGML_TYPE_F32: - f32_data = (float *)cur->data; - break; - case GGML_TYPE_F16: - if (conv_buf.size() < n_elms) { - conv_buf.resize(n_elms); - } - for (size_t j = 0; j < n_elms; ++j) { - conv_buf[j] = ggml_fp16_to_fp32(((ggml_fp16_t *)cur->data)[j]); - } - f32_data = (float *)conv_buf.data(); - break; - default: - LOG_ERR("%s: Please use an input file in f32 or f16\n", __func__); - gguf_free(ctx_out); - return false; - } - - if (work.size() < n_elms * 4) { - work.resize(n_elms * 4); - } - new_data = work.data(); - - new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, n_elms/cur->ne[0], cur->ne[0], nullptr); - } else { - new_type = cur->type; - new_data = cur->data; - new_size = ggml_nbytes(cur); - } - const size_t orig_size = ggml_nbytes(cur); - total_size_org += orig_size; - total_size_new += new_size; - gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data); - fout.write((const char *)new_data, new_size); - size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; - for (size_t j = 0; j < pad; ++j) { - fout.put(0); - } - - LOG_INF("%s: n_dims = %d | quantize=%d | size = %f MB -> %f MB\n", name.c_str(), ggml_n_dims(cur), quantize, - orig_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); - } - - // go back to beginning of file and write the updated metadata - fout.seekp(0, std::ios::beg); - std::vector meta(meta_size); - gguf_get_meta_data(ctx_out, meta.data()); - fout.write((const char *)meta.data(), meta_size); - - fout.close(); - - clip_free(ctx_clip); - gguf_free(ctx_out); - - { - LOG_INF("%s: original size = %8.2f MB\n", __func__, total_size_org / 1024.0 / 1024.0); - LOG_INF("%s: quantized size = %8.2f MB\n", __func__, total_size_new / 1024.0 / 1024.0); - } - - return true; -} - -int clip_n_mmproj_embd(const struct clip_ctx * ctx) { - if (ctx->proj_type == PROJECTOR_TYPE_LDP) { - return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0]; - } - if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) { - return ctx->vision_model.mm_model_peg_0_b->ne[0]; - } - if (ctx->proj_type == PROJECTOR_TYPE_MLP) { - return ctx->vision_model.mm_2_b->ne[0]; - } - if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) { - return ctx->vision_model.mm_3_b->ne[0]; - } - if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { - if (ctx->minicpmv_version == 2) { - return 4096; - } - else if (ctx->minicpmv_version == 3) { - return 3584; - } - else if (ctx->minicpmv_version == 4) { - return 3584; - } - } - if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){ - return ctx->vision_model.mm_model_mlp_3_w->ne[1]; - } - if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { - return ctx->vision_model.mm_1_b->ne[0]; - } - if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) { - return ctx->vision_model.mm_input_proj_w->ne[0]; - } - - std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type]; - throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str())); -} - -int clip_is_minicpmv(const struct clip_ctx * ctx) { - if (ctx->has_minicpmv_projector) { - return ctx->minicpmv_version; - } - return 0; -} - -bool clip_is_glm(const struct clip_ctx * ctx) { - return ctx->has_glm_projector; -} - -bool clip_is_qwen2vl(const struct clip_ctx * ctx) { - return ctx->has_qwen2vl_merger; -} - -bool clip_is_llava(const struct clip_ctx * ctx) { - return ctx->has_llava_projector; -} - -bool clip_is_gemma3(const struct clip_ctx * ctx) { - return ctx->proj_type == PROJECTOR_TYPE_GEMMA3; -} - -// Determine the number of encoder layers to iterate over -int get_deepest_feature_layer(const struct clip_ctx * ctx) { - // Get the index of the second to last layer; this is the - // default for models that have a llava projector - const auto & hparams = ctx->vision_model.hparams; - int n_layer = hparams.n_layer - 1; - int deepest_feature_layer = -1; - - // Handle other projectors; incrementing here indicates that we - // should use the last encoder layer for the vision features. - if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { - n_layer += 1; - } - - // If we set explicit vision feature layers, only go up to the deepest one - for (const auto & feature_layer : hparams.vision_feature_layer) { - if (feature_layer > deepest_feature_layer) { - deepest_feature_layer = feature_layer; - } - } - return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer; -} - -bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { - clip_image_f32 clip_img; - clip_img.buf.resize(h * w * 3); - for (int i = 0; i < h*w*3; i++) - { - clip_img.buf[i] = img[i]; - } - clip_img.nx = w; - clip_img.ny = h; - clip_image_encode(ctx, n_threads, &clip_img, vec); - return true; -} - -// -// API used internally with mtmd -// - -projector_type clip_get_projector_type(const struct clip_ctx * ctx) { - return ctx->proj_type; -} diff --git a/examples/llava/clip.h b/examples/llava/clip.h deleted file mode 100644 index cc133a58d..000000000 --- a/examples/llava/clip.h +++ /dev/null @@ -1,125 +0,0 @@ -#ifndef CLIP_H -#define CLIP_H - -#include "ggml.h" -#include -#include - -#ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define CLIP_API __declspec(dllexport) -# else -# define CLIP_API __declspec(dllimport) -# endif -# else -# define CLIP_API __attribute__ ((visibility ("default"))) -# endif -#else -# define CLIP_API -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -struct clip_ctx; - -struct clip_image_size { - int width; - int height; -}; - -struct clip_image_u8_batch; -struct clip_image_f32_batch; - -struct clip_context_params { - bool use_gpu; - ggml_log_level verbosity; -}; - -// deprecated, use clip_init -CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity); - -CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params); - -CLIP_API void clip_free(struct clip_ctx * ctx); - -CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); -CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w); - -CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx); -CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx); - -// TODO: should be enum, not string -CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); - -CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); -CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx); - -CLIP_API int clip_n_patches (const struct clip_ctx * ctx); -CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img); -CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx); - -CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); -CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); -CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); - -CLIP_API struct clip_image_size * clip_image_size_init(); -CLIP_API struct clip_image_u8 * clip_image_u8_init (); -CLIP_API struct clip_image_f32 * clip_image_f32_init(); -CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava - -// nx, ny are the output image dimensions -CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny); - -CLIP_API void clip_image_size_free (struct clip_image_size * img_size); -CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); -CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); -CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); -CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); - -// use for accessing underlay data of clip_image_f32_batch -CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size() -CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx -CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny -CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data - -/** - * Build image from pixels decoded by other libraries instead of stb_image.h for better performance. - * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes - */ -CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img); - -CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); - -/** interpret bytes as an image file with length bytes_length, and use the result to populate img */ -CLIP_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); - -/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */ -CLIP_API bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs ); - -CLIP_API struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx); - -CLIP_API bool clip_image_encode (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec); -CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec); - -CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); - -CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); -CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); -CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); -CLIP_API bool clip_is_llava(const struct clip_ctx * ctx); -CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx); - -CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx); - -CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); - - -#ifdef __cplusplus -} -#endif - -#endif // CLIP_H diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp deleted file mode 100644 index 91a07e2a8..000000000 --- a/examples/llava/gemma3-cli.cpp +++ /dev/null @@ -1,322 +0,0 @@ -#include "arg.h" -#include "log.h" -#include "common.h" -#include "sampling.h" -#include "llama.h" -#include "ggml.h" -#include "console.h" -#include "chat.h" -#include "mtmd.h" - -#include -#include -#include - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) -#include -#include -#elif defined (_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#include -#endif - -static bool g_is_generating = false; - -/** - * Please note that this is NOT a production-ready stuff. - * It is a playground for trying Gemma 3 vision capabilities. - * For contributors: please keep this code simple and easy to understand. - */ - -static void show_additional_info(int /*argc*/, char ** argv) { - LOG( - "Experimental CLI for using Gemma 3 vision model\n\n" - "Usage: %s [options] -m --mmproj --image -p \n\n" - " -m and --mmproj are required\n" - " --image and -p are optional, if NOT provided, the CLI will run in chat mode\n", - argv[0] - ); -} - -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) -static void sigint_handler(int signo) { - if (signo == SIGINT) { - if (g_is_generating) { - g_is_generating = false; - } else { - console::cleanup(); - LOG("\nInterrupted by user\n"); - _exit(130); - } - } -} -#endif - -struct gemma3_context { - mtmd_context_ptr ctx_vision; - common_init_result llama_init; - - llama_model * model; - llama_context * lctx; - const llama_vocab * vocab; - llama_batch batch; - int n_batch; - - // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another - // so here we don't need to keep track of chat history - common_chat_templates_ptr tmpls; - - int n_threads = 1; - llama_pos n_past = 0; - - gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) { - model = llama_init.model.get(); - lctx = llama_init.context.get(); - vocab = llama_model_get_vocab(model); - n_threads = params.cpuparams.n_threads; - batch = llama_batch_init(params.n_batch, 0, 1); - n_batch = params.n_batch; - tmpls = common_chat_templates_init(model, params.chat_template); - init_vision_context(params); - } - - void init_vision_context(common_params & params) { - const char * clip_path = params.mmproj.path.c_str(); - ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{ - /* use_gpu */ true, - /* timings */ true, - /* n_threads */ params.cpuparams.n_threads, - /* verbosity */ GGML_LOG_LEVEL_INFO, - })); - if (!ctx_vision.get()) { - LOG_ERR("Failed to load vision model from %s\n", clip_path); - exit(1); - } - } -}; - -struct decode_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - -static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) { - for (int i = 0; i < n_predict; i++) { - if (i > n_predict || !g_is_generating) { - printf("\n"); - break; - } - - llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1); - common_sampler_accept(smpl, token_id, true); - - if (llama_vocab_is_eog(ctx.vocab, token_id)) { - printf("\n"); - break; // end of generation - } - - printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str()); - fflush(stdout); - - // eval the token - common_batch_clear(ctx.batch); - common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true); - if (llama_decode(ctx.lctx, ctx.batch)) { - LOG_ERR("failed to decode token\n"); - return 1; - } - } - return 0; -} - -static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector & images_fname, bool add_bos = false) { - std::vector bitmaps; - - common_chat_templates_inputs tmpl_inputs; - tmpl_inputs.messages = {msg}; - tmpl_inputs.add_generation_prompt = true; - tmpl_inputs.use_jinja = false; // jinja is buggy here - auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs); - LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str()); - - for (auto & fname : images_fname) { - mtmd_bitmap bitmap; - if (mtmd_helper_bitmap_init_from_file(fname.c_str(), bitmap)) { - LOG_ERR("Unable to load image %s\n", fname.c_str()); - return 2; // image not found - } - bitmaps.push_back(std::move(bitmap)); - } - - mtmd_input_text text; - text.text = formatted_chat.prompt; - text.add_special = add_bos; - text.parse_special = true; - mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps)); - if (chunks == nullptr) { - LOG_ERR("Unable to tokenize prompt\n"); - return 1; - } - - if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) { - LOG_ERR("Unable to eval prompt\n"); - return 1; - } - - ctx.n_past += mtmd_helper_get_n_tokens(chunks.get()); - - return 0; -} - -int main(int argc, char ** argv) { - ggml_time_init(); - - common_params params; - params.sampling.temp = 0.2; // lower temp by default for better quality - - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) { - return 1; - } - - common_init(); - - if (params.mmproj.path.empty()) { - show_additional_info(argc, argv); - return 1; - } - - gemma3_context ctx(params); - printf("%s: %s\n", __func__, params.model.path.c_str()); - - bool is_single_turn = !params.prompt.empty() && !params.image.empty(); - - struct common_sampler * smpl = common_sampler_init(ctx.model, params.sampling); - int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict; - - // ctrl+C handling - { -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = sigint_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif - } - - if (is_single_turn) { - g_is_generating = true; - if (params.prompt.find("<__image__>") == std::string::npos) { - params.prompt += " <__image__>"; - } - common_chat_msg msg; - msg.role = "user"; - msg.content = params.prompt; - if (eval_message(ctx, msg, params.image, true)) { - return 1; - } - if (generate_response(ctx, smpl, n_predict)) { - return 1; - } - - } else { - LOG("\n Running in chat mode, available commands:"); - LOG("\n /image load an image"); - LOG("\n /clear clear the chat history"); - LOG("\n /quit or /exit exit the program"); - LOG("\n"); - - bool is_first_msg = true; - std::vector images_fname; - std::string content; - - while (true) { - g_is_generating = false; - LOG("\n> "); - console::set_display(console::user_input); - std::string line; - console::readline(line, false); - console::set_display(console::reset); - line = string_strip(line); - if (line.empty()) { - continue; - } - if (line == "/quit" || line == "/exit") { - break; - } - if (line == "/clear") { - ctx.n_past = 0; - llama_kv_self_seq_rm(ctx.lctx, 0, 1, -1); // keep BOS - LOG("Chat history cleared\n\n"); - continue; - } - g_is_generating = true; - if (line.find("/image") == 0) { - std::string image = line.substr(7); - images_fname.push_back(string_strip(image)); - content += "<__image__>"; - continue; - } else { - content += line; - } - common_chat_msg msg; - msg.role = "user"; - msg.content = content; - int ret = eval_message(ctx, msg, images_fname, is_first_msg); - if (ret == 2) { - // non-fatal error - images_fname.clear(); - content.clear(); - continue; - } - if (ret) { - return 1; - } - if (generate_response(ctx, smpl, n_predict)) { - return 1; - } - images_fname.clear(); - content.clear(); - is_first_msg = false; - } - } - - return 0; -} diff --git a/examples/llava/gemma3_convert_encoder_to_gguf.py b/examples/llava/gemma3_convert_encoder_to_gguf.py deleted file mode 100644 index 241b526b9..000000000 --- a/examples/llava/gemma3_convert_encoder_to_gguf.py +++ /dev/null @@ -1,307 +0,0 @@ -import gguf -import argparse -import logging -import sys -import torch -import json -import os -import numpy as np -from typing import cast, ContextManager, Any, Iterator -from pathlib import Path -from torch import Tensor - -logger = logging.getLogger("gemma3-mmproj") - - -# (copied from convert_hf_to_gguf.py) -# tree of lazy tensors -class LazyTorchTensor(gguf.LazyBase): - _tensor_type = torch.Tensor - # to keep the type-checker happy - dtype: torch.dtype - shape: torch.Size - - # only used when converting a torch.Tensor to a np.ndarray - _dtype_map: dict[torch.dtype, type] = { - torch.float16: np.float16, - torch.float32: np.float32, - } - - # used for safetensors slices - # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046 - # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734 - _dtype_str_map: dict[str, torch.dtype] = { - "F64": torch.float64, - "F32": torch.float32, - "BF16": torch.bfloat16, - "F16": torch.float16, - # "U64": torch.uint64, - "I64": torch.int64, - # "U32": torch.uint32, - "I32": torch.int32, - # "U16": torch.uint16, - "I16": torch.int16, - "U8": torch.uint8, - "I8": torch.int8, - "BOOL": torch.bool, - "F8_E4M3": torch.float8_e4m3fn, - "F8_E5M2": torch.float8_e5m2, - } - - def numpy(self) -> gguf.LazyNumpyTensor: - dtype = self._dtype_map[self.dtype] - return gguf.LazyNumpyTensor( - meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), - args=(self,), - func=(lambda s: s.numpy()) - ) - - @classmethod - def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor: - return torch.empty(size=shape, dtype=dtype, device="meta") - - @classmethod - def from_safetensors_slice(cls, st_slice: Any) -> Tensor: - dtype = cls._dtype_str_map[st_slice.get_dtype()] - shape: tuple[int, ...] = tuple(st_slice.get_shape()) - lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) - return cast(torch.Tensor, lazy) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - del types # unused - - if kwargs is None: - kwargs = {} - - if func is torch.Tensor.numpy: - return args[0].numpy() - - return cls._wrap_fn(func)(*args, **kwargs) - - -class Gemma3VisionTower: - hparams: dict - gguf_writer: gguf.GGUFWriter - fname_out: Path - ftype: gguf.LlamaFileType - - @staticmethod - def load_hparams(dir_model: Path): - with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) - - @staticmethod - def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]: - part_names: list[str] = [] - for filename in os.listdir(dir_model): - if filename.startswith(prefix) and filename.endswith(suffix): - part_names.append(filename) - part_names.sort() - return part_names - - def __init__(self, - dir_model: Path, - fname_out: Path, - ftype: gguf.LlamaFileType, - is_big_endian: bool,): - hparams = Gemma3VisionTower.load_hparams(dir_model) - self.hparams = hparams - self.fname_out = fname_out - self.ftype = ftype - endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE - self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess) - - text_config = hparams["text_config"] - vision_config = hparams["vision_config"] - - assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration" - assert text_config is not None - assert vision_config is not None - - self.gguf_writer.add_string ("clip.projector_type", "gemma3") - self.gguf_writer.add_bool ("clip.has_text_encoder", False) - self.gguf_writer.add_bool ("clip.has_vision_encoder", True) - self.gguf_writer.add_bool ("clip.has_llava_projector", False) # legacy - self.gguf_writer.add_uint32 ("clip.vision.image_size", vision_config["image_size"]) - self.gguf_writer.add_uint32 ("clip.vision.patch_size", vision_config["patch_size"]) - self.gguf_writer.add_uint32 ("clip.vision.embedding_length", vision_config["hidden_size"]) - self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length", vision_config["intermediate_size"]) - self.gguf_writer.add_uint32 ("clip.vision.projection_dim", text_config["hidden_size"]) - self.gguf_writer.add_uint32 ("clip.vision.block_count", vision_config["num_hidden_layers"]) - self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"]) - self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6)) - # default values taken from HF tranformers code - self.gguf_writer.add_array ("clip.vision.image_mean", [0.5, 0.5, 0.5]) - self.gguf_writer.add_array ("clip.vision.image_std", [0.5, 0.5, 0.5]) - self.gguf_writer.add_bool ("clip.use_gelu", True) - - # load tensors - for name, data_torch in self.get_tensors(dir_model): - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) - self.add_tensor(name, data_torch) - - def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]: - part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors") - tensor_names_from_parts: set[str] = set() - for part_name in part_names: - logger.info(f"gguf: loading model part '{part_name}'") - from safetensors import safe_open - ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu")) - with ctx as model_part: - tensor_names_from_parts.update(model_part.keys()) - - for name in model_part.keys(): - data = model_part.get_slice(name) - data = LazyTorchTensor.from_safetensors_slice(data) - yield name, data - - def add_tensor(self, name: str, data_torch: Tensor): - is_1d = len(data_torch.shape) == 1 - is_embd = ".embeddings." in name - old_dtype = data_torch.dtype - can_quantize = not is_1d and not is_embd - data_qtype = gguf.GGMLQuantizationType.F32 - - # this is to support old checkpoint - # TODO: remove this when we have the final model - name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.") - name = name.replace("multimodal_projector.", "multi_modal_projector.") - - # filter only vision tensors - if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."): - return - # prefix - name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.") - name = name.replace("vision_tower.vision_model.", "v.") - # projector and input embd - name = name.replace(".embeddings.patch_embedding.", ".patch_embd.") - name = name.replace(".embeddings.position_embedding.", ".position_embd.") - name = name.replace( - "multi_modal_projector.mm_input_projection_weight", - "mm.input_projection.weight" - ) - name = name.replace( - "multi_modal_projector.mm_soft_emb_norm.weight", - "mm.soft_emb_norm.weight" - ) - name = name.replace("post_layernorm.", "post_ln.") - # each block - name = name.replace(".self_attn.k_proj.", ".attn_k.") - name = name.replace(".self_attn.v_proj.", ".attn_v.") - name = name.replace(".self_attn.q_proj.", ".attn_q.") - name = name.replace(".self_attn.out_proj.", ".attn_out.") - name = name.replace(".layer_norm1.", ".ln1.") - name = name.replace(".layer_norm2.", ".ln2.") - name = name.replace(".mlp.fc1.", ".ffn_down.") - name = name.replace(".mlp.fc2.", ".ffn_up.") - - if can_quantize: - if self.ftype == gguf.LlamaFileType.ALL_F32: - data_qtype = gguf.GGMLQuantizationType.F32 - elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: - data_qtype = gguf.GGMLQuantizationType.F16 - elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: - data_qtype = gguf.GGMLQuantizationType.BF16 - elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: - data_qtype = gguf.GGMLQuantizationType.Q8_0 - else: - raise ValueError(f"Unsupported file type: {self.ftype}") - - # corrent norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector - # the other norm values are part of SigLIP model, and they are already correct - # ref code: Gemma3RMSNorm - if "soft_emb_norm.weight" in name: - logger.info(f"Correcting norm value for '{name}'") - data_torch = data_torch + 1 - - data = data_torch.numpy() - - try: - data = gguf.quants.quantize(data, data_qtype) - except Exception as e: - logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16") - data_qtype = gguf.GGMLQuantizationType.F16 - data = gguf.quants.quantize(data, data_qtype) - - # reverse shape to make it similar to the internal ggml dimension order - shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}" - logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") - - self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) - - def write(self): - self.gguf_writer.write_header_to_file(path=self.fname_out) - self.gguf_writer.write_kv_data_to_file() - self.gguf_writer.write_tensors_to_file(progress=True) - self.gguf_writer.close() - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Convert Gemma 3 vision tower safetensors to GGUF format",) - parser.add_argument( - "--outfile", type=Path, default="mmproj.gguf", - help="path to write to", - ) - parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", - help="output format", - ) - parser.add_argument( - "--bigendian", action="store_true", - help="model is executed on big endian machine", - ) - parser.add_argument( - "model", type=Path, - help="directory containing model file", - nargs="?", - ) - parser.add_argument( - "--verbose", action="store_true", - help="increase output verbosity", - ) - - args = parser.parse_args() - if args.model is None: - parser.error("the following arguments are required: model") - return args - - -def main() -> None: - args = parse_args() - - if args.verbose: - logging.basicConfig(level=logging.DEBUG) - else: - logging.basicConfig(level=logging.INFO) - - dir_model = args.model - - if not dir_model.is_dir(): - logger.error(f'Error: {args.model} is not a directory') - sys.exit(1) - - ftype_map: dict[str, gguf.LlamaFileType] = { - "f32": gguf.LlamaFileType.ALL_F32, - "f16": gguf.LlamaFileType.MOSTLY_F16, - "bf16": gguf.LlamaFileType.MOSTLY_BF16, - "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, - } - - logger.info(f"Loading model: {dir_model.name}") - - with torch.inference_mode(): - gemma3_vision_tower = Gemma3VisionTower( - dir_model=dir_model, - fname_out=args.outfile, - ftype=ftype_map[args.outtype], - is_big_endian=args.bigendian, - ) - gemma3_vision_tower.write() - - -if __name__ == '__main__': - main() - diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp deleted file mode 100644 index 0fe0e333a..000000000 --- a/examples/llava/llava-cli.cpp +++ /dev/null @@ -1,332 +0,0 @@ -#include "arg.h" -#include "base64.hpp" -#include "log.h" -#include "common.h" -#include "sampling.h" -#include "clip.h" -#include "llava.h" -#include "llama.h" -#include "ggml.h" - -#include -#include -#include -#include - -static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past) { - int N = (int) tokens.size(); - for (int i = 0; i < N; i += n_batch) { - int n_eval = (int) tokens.size() - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { - LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); - return false; - } - *n_past += n_eval; - } - return true; -} - -static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { - std::vector tokens; - tokens.push_back(id); - return eval_tokens(ctx_llama, tokens, 1, n_past); -} - -static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ - std::string str2 = str; - std::vector embd_inp = common_tokenize(ctx_llama, str2, add_bos, true); - eval_tokens(ctx_llama, embd_inp, n_batch, n_past); - return true; -} - -static const char * sample(struct common_sampler * smpl, - struct llama_context * ctx_llama, - int * n_past) { - const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); - common_sampler_accept(smpl, id, true); - - const llama_model * model = llama_get_model(ctx_llama); - const llama_vocab * vocab = llama_model_get_vocab(model); - - static std::string ret; - if (llama_vocab_is_eog(vocab, id)) { - ret = ""; - } else { - ret = common_token_to_piece(ctx_llama, id); - } - eval_id(ctx_llama, id, n_past); - return ret.c_str(); -} - -static const char* IMG_BASE64_TAG_BEGIN = ""; - -static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) { - begin_out = prompt.find(IMG_BASE64_TAG_BEGIN); - end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out); -} - -static bool prompt_contains_image(const std::string& prompt) { - size_t begin, end; - find_image_tag_in_prompt(prompt, begin, end); - return (begin != std::string::npos); -} - -// replaces the base64 image tag in the prompt with `replacement` -static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) { - size_t img_base64_str_start, img_base64_str_end; - find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end); - if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) { - LOG_ERR("%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END); - return NULL; - } - - auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN); - auto base64_bytes_count = img_base64_str_end - base64_bytes_start; - auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count ); - - auto required_bytes = base64::required_encode_size(base64_str.size()); - auto img_bytes = std::vector(required_bytes); - base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin()); - - auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size()); - if (!embed) { - LOG_ERR("%s: could not load image from base64 string.\n", __func__); - return NULL; - } - - return embed; -} - -static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") { - size_t begin, end; - find_image_tag_in_prompt(prompt, begin, end); - if (begin == std::string::npos || end == std::string::npos) { - return prompt; - } - auto pre = prompt.substr(0, begin); - auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END)); - return pre + replacement + post; -} - -struct llava_context { - struct clip_ctx * ctx_clip = NULL; - struct llama_context * ctx_llama = NULL; - struct llama_model * model = NULL; -}; - -static void print_usage(int, char ** argv) { - LOG("\n example usage:\n"); - LOG("\n %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); - LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n"); -} - -static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) { - - // load and preprocess the image - llava_image_embed * embed = NULL; - auto prompt = params->prompt; - if (prompt_contains_image(prompt)) { - if (!params->image.empty()) { - LOG_INF("using base64 encoded image instead of command line image path\n"); - } - embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt); - if (!embed) { - LOG_ERR("%s: can't load image from prompt\n", __func__); - return NULL; - } - params->prompt = remove_image_from_prompt(prompt); - } else { - embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str()); - if (!embed) { - fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str()); - return NULL; - } - } - - return embed; -} - -static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) { - int n_past = 0; - - const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; - - std::string system_prompt, user_prompt; - size_t image_pos = prompt.find(""); - if (image_pos != std::string::npos) { - // new templating mode: Provide the full prompt including system message and use as a placeholder for the image - system_prompt = prompt.substr(0, image_pos); - user_prompt = prompt.substr(image_pos + std::string("").length()); - LOG_INF("system_prompt: %s\n", system_prompt.c_str()); - if (params->verbose_prompt) { - auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { - LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); - } - } - LOG_INF("user_prompt: %s\n", user_prompt.c_str()); - if (params->verbose_prompt) { - auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { - LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); - } - } - } else { - // llava-1.5 native mode - system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:"; - user_prompt = prompt + "\nASSISTANT:"; - if (params->verbose_prompt) { - auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { - LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); - } - } - } - - eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true); - llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past); - eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); - - // generate the response - - LOG("\n"); - - struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); - if (!smpl) { - LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); - exit(1); - } - - std::string response = ""; - for (int i = 0; i < max_tgt_len; i++) { - const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); - response += tmp; - if (strcmp(tmp, "") == 0) break; - if (strstr(tmp, "###")) break; // Yi-VL behavior - LOG("%s", tmp); - if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works) - if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6 - if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6 - - fflush(stdout); - } - - common_sampler_free(smpl); - LOG("\n"); -} - -static struct llama_model * llava_init(common_params * params) { - llama_backend_init(); - llama_numa_init(params->numa); - - llama_model_params model_params = common_model_params_to_llama(*params); - - llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params); - if (model == NULL) { - LOG_ERR("%s: unable to load model\n" , __func__); - return NULL; - } - return model; -} - -static struct llava_context * llava_init_context(common_params * params, llama_model * model) { - const char * clip_path = params->mmproj.path.c_str(); - - auto prompt = params->prompt; - if (prompt.empty()) { - prompt = "describe the image in detail."; - } - - auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO); - - llama_context_params ctx_params = common_context_params_to_llama(*params); - ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - - llama_context * ctx_llama = llama_init_from_model(model, ctx_params); - - if (ctx_llama == NULL) { - LOG_ERR("%s: failed to create the llama_context\n" , __func__); - return NULL; - } - - auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); - - ctx_llava->ctx_llama = ctx_llama; - ctx_llava->ctx_clip = ctx_clip; - ctx_llava->model = model; - return ctx_llava; -} - -static void llava_free(struct llava_context * ctx_llava) { - if (ctx_llava->ctx_clip) { - clip_free(ctx_llava->ctx_clip); - ctx_llava->ctx_clip = NULL; - } - - llama_free(ctx_llava->ctx_llama); - llama_model_free(ctx_llava->model); - llama_backend_free(); -} - -int main(int argc, char ** argv) { - ggml_time_init(); - - common_params params; - - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) { - return 1; - } - - common_init(); - - if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { - print_usage(argc, argv); - return 1; - } - - auto * model = llava_init(¶ms); - if (model == NULL) { - fprintf(stderr, "%s: error: failed to init llava model\n", __func__); - return 1; - } - - if (prompt_contains_image(params.prompt)) { - auto * ctx_llava = llava_init_context(¶ms, model); - - auto * image_embed = load_image(ctx_llava, ¶ms, ""); - - // process the prompt - process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - - llama_perf_context_print(ctx_llava->ctx_llama); - llava_image_embed_free(image_embed); - ctx_llava->model = NULL; - llava_free(ctx_llava); - } else { - for (auto & image : params.image) { - auto * ctx_llava = llava_init_context(¶ms, model); - - auto * image_embed = load_image(ctx_llava, ¶ms, image); - if (!image_embed) { - LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str()); - return 1; - } - - // process the prompt - process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - - llama_perf_context_print(ctx_llava->ctx_llama); - llava_image_embed_free(image_embed); - ctx_llava->model = NULL; - llava_free(ctx_llava); - } - } - - llama_model_free(model); - - return 0; -} diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp deleted file mode 100644 index 03a22cbb4..000000000 --- a/examples/llava/llava.cpp +++ /dev/null @@ -1,585 +0,0 @@ -#include "clip.h" -#include "llava.h" - -#include "llama.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(LLAVA_LOG_OFF) -# define LOG_INF(...) -# define LOG_WRN(...) -# define LOG_ERR(...) -# define LOG_DBG(...) -#else // defined(LLAVA_LOG_OFF) -# define LOG_INF(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -# define LOG_WRN(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_ERR(...) do { fprintf(stderr, __VA_ARGS__); } while (0) -# define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0) -#endif // defined(LLAVA_LOG_OFF) - -// RGB uint8 image -struct clip_image_u8 { - int nx; - int ny; - - std::vector buf; -}; - -// RGB float32 image (NHWC) -// Memory layout: RGBRGBRGB... -struct clip_image_f32 { - int nx; - int ny; - - std::vector buf; -}; - -struct clip_image_grid_shape { - int first; - int second; -}; - -// convenience cpp wrapper -struct clip_image_f32_batch_deleter { - void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } -}; -typedef std::unique_ptr clip_image_f32_batch_ptr; - -struct clip_image_size_deleter { - void operator()(clip_image_f32_batch * val) { clip_image_f32_batch_free(val); } -}; -typedef std::unique_ptr clip_image_size_ptr; - -/** - * Selects the best resolution from a list of possible resolutions based on the original size. - * - * @param original_size The original size of the image in the format (width, height). - * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. - * @return The best fit resolution in the format (width, height). - */ -static std::pair select_best_resolution(const std::pair& original_size, const std::vector>& possible_resolutions) { - int original_width = original_size.first; - int original_height = original_size.second; - - std::pair best_fit; - int max_effective_resolution = 0; - int min_wasted_resolution = std::numeric_limits::max(); - - for (const auto& resolution : possible_resolutions) { - int width = resolution.first; - int height = resolution.second; - float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); - int downscaled_width = static_cast(original_width * scale); - int downscaled_height = static_cast(original_height * scale); - int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); - int wasted_resolution = (width * height) - effective_resolution; - // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); - if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { - max_effective_resolution = effective_resolution; - min_wasted_resolution = wasted_resolution; - best_fit = resolution; - } - } - - return best_fit; -} - -/** - * @brief Get the anyres image grid shape object - * - * @param image_size - * @param grid_pinpoints - * @param image_patch_size - * @return - */ -static struct clip_image_grid_shape get_anyres_image_grid_shape(const std::pair & image_size, const std::vector> & grid_pinpoints, int image_patch_size) { - /** - Conversion from gguf flat array to vector: - std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { - possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); - } - */ - auto best_resolution = select_best_resolution(image_size, grid_pinpoints); - return {best_resolution.first / image_patch_size, best_resolution.second / image_patch_size}; -} - -// Take the image segments in a grid configuration and return the embeddings and the number of embeddings into preallocated memory (image_embd_out) -static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector & image_embd_v, struct clip_image_grid_shape grid_shape, float * image_embd_out, int * n_img_pos_out) { - struct { - struct ggml_context * ctx; - } model; - - const int32_t image_size = clip_get_image_size(ctx_clip); - const int32_t patch_size = clip_get_patch_size(ctx_clip); - - int32_t num_patches_per_side = image_size / patch_size; // 336 / 14 = 24 - used for embedding-patching boxes (24*24 = 576 patches) - - int num_patches_width = grid_shape.first; // grid 1-4 - int num_patches_height = grid_shape.second; // grid 1-4 - - const size_t num_images = num_patches_width * num_patches_height + 1; - - // TODO: size calculation is not calculated - it's only tens of MB - size_t ctx_size = 0; - - { - ctx_size += clip_embd_nbytes(ctx_clip) * num_images * 8; // image_features - ctx_size += 1024*1024 * ggml_type_size(GGML_TYPE_F32); - } - - struct ggml_init_params params { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ false, // NOTE: this should be false when using the legacy API - }; - - // Python reference code for full unpad: - /* - base_image_feature = image_feature[0] - image_feature = image_feature[1:] - image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() - image_feature = image_feature.flatten(1, 2).flatten(2, 3) - image_feature = unpad_image(image_feature, image_sizes[image_idx]) - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1) - ), dim=-1) - image_feature = image_feature.flatten(1, 2).transpose(0, 1) - image_feature = torch.cat((base_image_feature, image_feature), dim=0) - */ - // We now have two options: unpad or no unpad. Unpad removes tokens for faster llm eval. - // In terms of result quality it appears to make no difference, so we'll start with the easier approach given 5D tensors are not supported in ggml yet. - // Without unpad we have to split the sub-image embeddings into patches of 24 features each and permute them. - // Once all images are processed to prepended the base_image_features without any changes. - - // Pytorch reference simplified, modified for ggml compatibility - confirmed identical output in python (for a 2x2 grid image (676x676 scaling)) - /* - image_feature = image_feature.view(2, 2, 24, 24, 4096) - image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() - image_feature = image_feature.view(2, 24, 2, 24, 4096) - image_feature = image_feature.flatten(0, 3) - - // Reshape to 4D tensor by merging the last two dimensions - image_feature = image_feature.view(2, 2, 24, 24*4096) - image_feature = image_feature.permute(0, 2, 1, 3).contiguous() - image_feature = image_feature.view(-1, 4096) - */ - - model.ctx = ggml_init(params); - - struct ggml_tensor * image_features = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, clip_n_mmproj_embd(ctx_clip), clip_n_patches(ctx_clip), num_images - 1); // example: 4096 x 576 x 4 - // ggml_tensor_printf(image_features,"image_features",__LINE__,false,false); - // fill it with the image embeddings, ignoring the base - for (size_t i = 1; i < num_images; i++) { - size_t offset = (i-1) * clip_embd_nbytes(ctx_clip); - memcpy((uint8_t *)(image_features->data) + offset, image_embd_v[i], clip_embd_nbytes(ctx_clip)); - } - - struct ggml_cgraph * gf = ggml_new_graph(model.ctx); - size_t size_ele = ggml_type_size(GGML_TYPE_F32); - - struct ggml_tensor *image_features_patchview = ggml_view_4d(model.ctx, image_features, - num_patches_per_side * clip_n_mmproj_embd(ctx_clip), - num_patches_per_side, - num_patches_width, - num_patches_height, - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip), - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side, - size_ele * num_patches_per_side * clip_n_mmproj_embd(ctx_clip) * num_patches_per_side * num_patches_width, 0); - // ggml_tensor_printf(image_features_patchview,"image_features_patchview",__LINE__,false,false); - struct ggml_tensor *permuted_cont = ggml_cont(model.ctx, ggml_permute(model.ctx, image_features_patchview, 0, 2, 1, 3)); - /** - At the end of each row we have to add the row_end embeddings, which are the same as the newline embeddings - image_feature = torch.cat(( - image_feature, - self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) - ), dim=-1) - * - */ - - // ggml_tensor_printf(permuted_cont,"permuted_cont",__LINE__,false,false); - struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0); - // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); - ggml_build_forward_expand(gf, flatten); - ggml_graph_compute_with_ctx(model.ctx, gf, 1); - struct ggml_tensor* result = ggml_graph_node(gf, -1); - - memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context - // append without newline tokens (default behavior in llava_arch when not using unpad ): - memcpy(image_embd_out + clip_n_patches(ctx_clip) * clip_n_mmproj_embd(ctx_clip), (float*)result->data, clip_embd_nbytes(ctx_clip) * (num_images-1)); // grid patches - *n_img_pos_out = static_cast(result->ne[1]+clip_n_patches(ctx_clip)); - - // Debug: Test single segments - // Current findings: sending base image, sending a segment embedding all works similar to python - // However, permuted embeddings do not work yet (stride issue?) - // memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as context - // memcpy(image_embd_out, (float*)prepared_cont->data, clip_embd_nbytes(ctx_clip)); // main image as context - // *n_img_pos_out=576; - - ggml_free(model.ctx); - return true; -} - -static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { - int width = image->nx; - int height = image->ny; - int num_patches = (height / patch_size) * (width / patch_size); - clip_image_f32 * patch = clip_image_f32_init(); - patch->nx = patch_size * num_patches; - patch->ny = patch_size; - patch->buf.resize(3 * patch->nx * patch->ny); - - int patch_index = 0; - - for (int i = 0; i < height; i += patch_size) { - for (int j = 0; j < width; j += patch_size) { - for (int pi = 0; pi < patch_size; ++pi) { - for (int pj = 0; pj < patch_size; ++pj) { - int input_index = ((i + pi) * width + (j + pj)) * 3; - int output_index = (pi * patch_size * num_patches + patch_index * patch_size + pj) * 3; - patch->buf[output_index] = image->buf[input_index]; - patch->buf[output_index+1] = image->buf[input_index+1]; - patch->buf[output_index+2] = image->buf[input_index+2]; - } - } - patch_index++; - } - } - return patch; -} - -static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { - // std::vector img_res_v; // format VectN x H x W x RGB (N x 336 x 336 x 3), so interleaved RGB - different to the python implementation which is N x 3 x 336 x 336 - clip_image_f32_batch_ptr img_res_v(clip_image_f32_batch_init()); - if (!clip_image_preprocess(ctx_clip, img, img_res_v.get())) { - LOG_ERR("%s: unable to preprocess image\n", __func__); - return false; - } - - const int64_t t_img_enc_start_us = ggml_time_us(); - - const char * mm_patch_merge_type = clip_patch_merge_type(ctx_clip); - - const size_t n_imgs = clip_image_f32_batch_n_images(img_res_v.get()); - - if (clip_is_minicpmv(ctx_clip) || clip_is_qwen2vl(ctx_clip)) { - std::vector image_embd_v; - image_embd_v.resize(n_imgs); - clip_image_size load_image_size; - - for (size_t i = 0; i < n_imgs; i++) { - const int64_t t_img_enc_step_start_us = ggml_time_us(); - int nx = clip_image_f32_batch_nx(img_res_v.get(), i); - int ny = clip_image_f32_batch_ny(img_res_v.get(), i); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, nx, ny)); - int patch_size = 14; - load_image_size.width = nx; - load_image_size.height = ny; - clip_add_load_image_size(ctx_clip, &load_image_size); - - bool encoded = false; - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - if (clip_is_qwen2vl(ctx_clip)) { - encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); - } - else { - encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(img_res, patch_size), image_embd_v[i]); - } - - if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); - return false; - } - const int64_t t_img_enc_steop_batch_us = ggml_time_us(); - LOG_INF("%s: step %d of %d encoded in %8.2f ms\n", __func__, (int)i+1, (int)n_imgs, (t_img_enc_steop_batch_us - t_img_enc_step_start_us) / 1000.0); - } - const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: all %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); - - int n_img_pos_out = 0; - for (size_t i = 0; i < image_embd_v.size(); i++) { - int nx = clip_image_f32_batch_nx(img_res_v.get(), i); - int ny = clip_image_f32_batch_ny(img_res_v.get(), i); - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - std::memcpy( - image_embd + n_img_pos_out * clip_n_mmproj_embd(ctx_clip), - image_embd_v[i], - clip_embd_nbytes_by_img(ctx_clip, nx, ny)); - n_img_pos_out += clip_n_patches_by_img(ctx_clip, img_res); - } - *n_img_pos = n_img_pos_out; - for (size_t i = 0; i < image_embd_v.size(); i++) { - free(image_embd_v[i]); - } - image_embd_v.clear(); - load_image_size.width = img->nx; - load_image_size.height = img->ny; - clip_add_load_image_size(ctx_clip, &load_image_size); - LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size.width, load_image_size.height); - } - else if (clip_is_glm(ctx_clip)){ - struct clip_image_size * load_image_size = clip_image_size_init(); - load_image_size->width = clip_image_f32_batch_nx(img_res_v.get(), 0); - load_image_size->height = clip_image_f32_batch_ny(img_res_v.get(), 0); - clip_add_load_image_size(ctx_clip, load_image_size); - - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); - int pos = int(load_image_size->width/clip_get_patch_size(ctx_clip)/2); - *n_img_pos = (pos * pos + 2); - if (!encoded){ - LOG_ERR("Unable to encode image \n"); - return false; - } - } - else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { - // flat / default llava-1.5 type embedding - *n_img_pos = clip_n_patches(ctx_clip); - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), 0); - bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); // image_embd shape is 576 x 4096 - if (!encoded) { - LOG_ERR("Unable to encode image\n"); - - return false; - } - } - else { - // spatial_unpad llava-1.6 type embedding - // TODO: CLIP needs batching support - in HF the llm projection is separate after encoding, which might be a solution to quickly get batching working - std::vector image_embd_v; - image_embd_v.resize(n_imgs); - for (size_t i = 0; i < n_imgs; i++) { - clip_image_f32 * img_res = clip_image_f32_get_img(img_res_v.get(), i); - image_embd_v[i] = (float *)malloc(clip_embd_nbytes(ctx_clip)); // 576 patches * 4096 embeddings * 4 bytes = 9437184 - const bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd_v[i]); // image data is in 3x336x336 format and will be converted to 336x336x3 inside - if (!encoded) { - LOG_ERR("Unable to encode image - spatial_unpad - subimage %d of %d\n", (int) i+1, (int) n_imgs); - return false; - } - } - const int64_t t_img_enc_batch_us = ggml_time_us(); - LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)n_imgs, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); - - const int32_t * image_grid = clip_image_grid(ctx_clip); - const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); - - std::vector> grid_pinpoints; - for (size_t i = 0; i < num_gridpoints; i += 2) { - grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); - } - - const int32_t image_size = clip_get_image_size(ctx_clip); - - struct clip_image_grid_shape grid_shape = get_anyres_image_grid_shape({img->nx,img->ny}, grid_pinpoints, image_size); - - int n_img_pos_out; - clip_llava_handle_patches(ctx_clip, image_embd_v, grid_shape, image_embd, &n_img_pos_out); - *n_img_pos = n_img_pos_out; - - for (size_t i = 0; i < image_embd_v.size(); i++) { - free(image_embd_v[i]); - } - image_embd_v.clear(); - - // debug image/segment/normalization content: - // clip_image_u8 * tmp = clip_image_u8_init(); - // clip_image_convert_f32_to_u8(*image_feature, *tmp); - // clip_image_save_to_bmp(*tmp, "image_feature.bmp"); - } - - LOG_INF("%s: image embedding created: %d tokens\n", __func__, *n_img_pos); - - const int64_t t_img_enc_end_us = ggml_time_us(); - float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0; - - LOG_INF("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos); - - return true; -} - -bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { - // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); - auto n_image_embd = clip_n_mmproj_embd(ctx_clip); - if (n_image_embd != n_llama_embd) { - LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); - return false; - } - return true; -} - -bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - // Granite vision uses up to 10 patches + base patch - int num_max_patches = 11; - if (clip_is_minicpmv(ctx_clip)) { - num_max_patches = 10; - } - if (clip_is_glm(ctx_clip)) { - num_max_patches = 1; - } - float * image_embd; - if (clip_is_qwen2vl(ctx_clip)) { - // qwen2vl don't split image into chunks, so `num_max_patches` is not needed. - image_embd = (float *)malloc(clip_embd_nbytes_by_img(ctx_clip, img->nx, img->ny)); - } else { - image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)*num_max_patches); // TODO: base on gridsize/llava model - } - if (!image_embd) { - LOG_ERR("Unable to allocate memory for image embeddings\n"); - return false; - } - - int n_img_pos; - if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) { - LOG_ERR("%s: cannot encode image, aborting\n", __func__); - free(image_embd); - return false; - } - *image_embd_out = image_embd; - *n_img_pos_out = n_img_pos; - - return true; -} - -struct llava_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - -bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); - - for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { - int n_eval = image_embed->n_image_pos - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - float * embd = image_embed->embed+i*n_embd; - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); - if (llama_decode(ctx_llama, llava_batch.batch)) { - LOG_ERR("%s : failed to eval\n", __func__); - return false; - } - *n_past += n_eval; - } - return true; -} - -struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) { - clip_image_u8 * img = clip_image_u8_init(); - if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) { - clip_image_u8_free(img); - LOG_ERR("%s: can't load image from bytes, is it a valid image?", __func__); - return NULL; - } - - float* image_embed = NULL; - int n_image_pos = 0; - bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos); - if (!image_embed_result) { - clip_image_u8_free(img); - LOG_ERR("%s: couldn't embed the image\n", __func__); - return NULL; - } - - clip_image_u8_free(img); - auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed)); - result->embed = image_embed; - result->n_image_pos = n_image_pos; - return result; -} - -static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) { - auto file = fopen(path, "rb"); - if (file == NULL) { - LOG_ERR("%s: can't read file %s\n", __func__, path); - return false; - } - - fseek(file, 0, SEEK_END); - auto fileSize = ftell(file); - fseek(file, 0, SEEK_SET); - - auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data - if (buffer == NULL) { - LOG_ERR("%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path); - perror("Memory allocation error"); - fclose(file); - return false; - } - errno = 0; - size_t ret = fread(buffer, 1, fileSize, file); // Read the file into the buffer - if (ferror(file)) { - LOG_ERR("read error: %s", strerror(errno)); - free(buffer); - fclose(file); - return false; - } - if (ret != (size_t) fileSize) { - LOG_ERR("unexpectedly reached end of file"); - free(buffer); - fclose(file); - return false; - } - fclose(file); // Close the file - - *bytesOut = buffer; - *sizeOut = fileSize; - return true; -} - -struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) { - unsigned char* image_bytes; - long image_bytes_length; - auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length); - if (!loaded) { - LOG_ERR("%s: failed to load %s\n", __func__, image_path); - return NULL; - } - - llava_image_embed *embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length); - free(image_bytes); - - return embed; -} - -void llava_image_embed_free(struct llava_image_embed * embed) { - free(embed->embed); - free(embed); -} diff --git a/examples/llava/llava.h b/examples/llava/llava.h deleted file mode 100644 index b6feb3027..000000000 --- a/examples/llava/llava.h +++ /dev/null @@ -1,49 +0,0 @@ -#ifndef LLAVA_H -#define LLAVA_H - -#include "ggml.h" - -#ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define LLAVA_API __declspec(dllexport) -# else -# define LLAVA_API __declspec(dllimport) -# endif -# else -# define LLAVA_API __attribute__ ((visibility ("default"))) -# endif -#else -# define LLAVA_API -#endif - -#ifdef __cplusplus -extern "C" { -#endif - -struct clip_ctx; -struct llava_image_embed { - float * embed; - int n_image_pos; -}; - -/** sanity check for clip <-> llava embed size match */ -LLAVA_API bool llava_validate_embed_size(const struct llama_context * ctx_llama, const struct clip_ctx * ctx_clip); - -LLAVA_API bool llava_image_embed_make_with_clip_img(struct clip_ctx * ctx_clip, int n_threads, const struct clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out); - -/** build an image embed from image file bytes */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); -/** build an image embed from a path to an image filename */ -LLAVA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); -/** free an embedding made with llava_image_embed_make_* */ -LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed); - -/** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ -LLAVA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp deleted file mode 100644 index 5ad970c22..000000000 --- a/examples/llava/minicpmv-cli.cpp +++ /dev/null @@ -1,354 +0,0 @@ -#include "arg.h" -#include "log.h" -#include "common.h" -#include "sampling.h" -#include "clip.h" -#include "llava.h" -#include "llama.h" -#include "ggml.h" - -#include -#include -#include -#include -#include -#include // TODO: remove me - -struct llava_context { - struct clip_ctx * ctx_clip = NULL; - struct llama_context * ctx_llama = NULL; - struct llama_model * model = NULL; -}; - -static void show_additional_info(int /*argc*/, char ** argv) { - LOG("\nexample usage:\n\n%s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); - LOG("\nnote: a lower temperature value like 0.1 is recommended for better quality.\n"); -} - -static struct llama_model * llava_init(common_params * params) { - llama_backend_init(); - llama_numa_init(params->numa); - - llama_model_params model_params = common_model_params_to_llama(*params); - - llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params); - if (model == NULL) { - LOG_ERR("%s: unable to load model\n" , __func__); - return NULL; - } - return model; -} - -static struct llava_context * llava_init_context(common_params * params, llama_model * model) { - auto prompt = params->prompt; - if (prompt.empty()) { - prompt = "describe the image in detail."; - } - - llama_context_params ctx_params = common_context_params_to_llama(*params); - if (params->n_ctx < 2048) { - // warn user here, "Image processing requires at least 2048 context, setting context to 2048" - LOG_WRN("%s: Image processing requires at least 2048 context, setting context to 2048\n" , __func__); - ctx_params.n_ctx = 2048; - } else { - ctx_params.n_ctx = params->n_ctx; - } - - llama_context * ctx_llama = llama_init_from_model(model, ctx_params); - - if (ctx_llama == NULL) { - LOG_ERR("%s: failed to create the llama_context\n" , __func__); - return NULL; - } - - auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); - - ctx_llava->ctx_llama = ctx_llama; - ctx_llava->model = model; - return ctx_llava; -} - -static void llava_free(struct llava_context * ctx_llava) { - if (ctx_llava->ctx_clip) { - clip_free(ctx_llava->ctx_clip); - ctx_llava->ctx_clip = NULL; - } - - llama_free(ctx_llava->ctx_llama); - llama_model_free(ctx_llava->model); - llama_backend_free(); -} - -static struct clip_ctx * clip_init_context(common_params * params) { - const char * clip_path = params->mmproj.path.c_str(); - - auto prompt = params->prompt; - if (prompt.empty()) { - prompt = "describe the image in detail."; - } - struct clip_context_params clip_params = { - /* use_gpu */ params->n_gpu_layers != 0, - /* verbosity */ GGML_LOG_LEVEL_INFO, // TODO: make this configurable - }; - auto * ctx_clip = clip_init(clip_path, clip_params); - return ctx_clip; -} - -static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past) { - int N = (int) tokens.size(); - for (int i = 0; i < N; i += n_batch) { - int n_eval = (int) tokens.size() - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { - LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); - return false; - } - *n_past += n_eval; - } - return true; -} - -static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { - std::vector tokens; - tokens.push_back(id); - return eval_tokens(ctx_llama, tokens, 1, n_past); -} - -static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ - std::string str2 = str; - std::vector embd_inp = common_tokenize(ctx_llama, str2, add_bos, true); - return eval_tokens(ctx_llama, embd_inp, n_batch, n_past); -} - -static void process_eval_image_embed(struct llava_context * ctx_llava, const struct llava_image_embed * embeds, int n_batch, int * n_past, int idx) { - float * image_embed = (float *)malloc(clip_embd_nbytes(ctx_llava->ctx_clip)); - std::memcpy(image_embed, embeds->embed + idx * clip_n_patches(ctx_llava->ctx_clip) * clip_n_mmproj_embd(ctx_llava->ctx_clip), clip_embd_nbytes(ctx_llava->ctx_clip)); - - auto * slice_embed = (llava_image_embed*)malloc(sizeof(llava_image_embed)); - slice_embed->embed = image_embed; - slice_embed->n_image_pos = clip_n_patches(ctx_llava->ctx_clip); - llava_eval_image_embed(ctx_llava->ctx_llama, slice_embed, n_batch, n_past); - llava_image_embed_free(slice_embed); -} - -static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, common_params * params, int &n_past) { - std::string system_prompt; - int idx = 0; - int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip); - int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); - if (has_minicpmv_projector == 2) { - system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"; - } - else if (has_minicpmv_projector == 3) { - system_prompt = "<|im_start|>user\n"; - } - else if (has_minicpmv_projector == 4) { - system_prompt = "<|im_start|>user\n"; - } - LOG_INF("%s: image token past: %d\n", __func__, n_past); - eval_string(ctx_llava->ctx_llama, (system_prompt+"").c_str(), params->n_batch, &n_past, false); - process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); - eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); - if (num_image_embeds > 1) { - if (has_minicpmv_projector == 2) { - size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip); - eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); - for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) { - for (size_t j = 0; j < num_image_embeds_col; ++j) { - eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); - process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); - eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); - if (j == num_image_embeds_col - 1) { - eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false); - } - } - } - eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); - } - else if (has_minicpmv_projector == 3 || has_minicpmv_projector == 4) { - size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip); - for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) { - for (size_t j = 0; j < num_image_embeds_col; ++j) { - eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); - process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++); - eval_string(ctx_llava->ctx_llama, std::string("").c_str(), params->n_batch, &n_past, false); - if (j == num_image_embeds_col - 1) { - eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false); - } - } - } - } - } - LOG_INF("%s: image token past: %d\n", __func__, n_past); -} - -static const char * sample(struct common_sampler * smpl, - struct llama_context * ctx_llama, - int * n_past) { - const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); - common_sampler_accept(smpl, id, true); - - const llama_model * model = llama_get_model(ctx_llama); - const llama_vocab * vocab = llama_model_get_vocab(model); - - static std::string ret; - if (llama_vocab_is_eog(vocab, id)) { - ret = ""; - } else { - ret = common_token_to_piece(ctx_llama, id); - } - eval_id(ctx_llama, id, n_past); - return ret.c_str(); -} - -static struct llava_context * minicpmv_init(common_params * params, const std::string & fname, int &n_past){ - auto * ctx_clip = clip_init_context(params); - auto * embeds = llava_image_embed_make_with_filename(ctx_clip, params->cpuparams.n_threads, fname.c_str()); - if (!embeds) { - LOG_ERR("failed to load image %s. Terminating\n\n", fname.c_str()); - return NULL; - } - - // process the prompt - if (params->prompt.empty() && params->interactive == false) { - LOG_ERR("prompt should be given or interactive mode should be on"); - return NULL; - } - - auto * model = llava_init(params); - if (model == NULL) { - fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__); - return NULL; - } - const int64_t t_llava_init_start_us = ggml_time_us(); - auto * ctx_llava = llava_init_context(params, model); - ctx_llava->ctx_clip = ctx_clip; - const int64_t t_llava_init_end_us = ggml_time_us(); - float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0; - LOG_INF("%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms); - - const int64_t t_process_image_start_us = ggml_time_us(); - process_image(ctx_llava, embeds, params, n_past); - const int64_t t_process_image_end_us = ggml_time_us(); - float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0; - LOG_INF("%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms); - - llava_image_embed_free(embeds); - return ctx_llava; -} - -static struct common_sampler * llama_init(struct llava_context * ctx_llava, common_params * params, const std::string & prompt, int & n_past, bool is_first = false){ - std::string user_prompt = prompt; - int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); - if (!is_first) { - if (has_minicpmv_projector == 2) { - user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt; - } - else if (has_minicpmv_projector == 3) { - user_prompt = "<|im_start|>user\n" + prompt; - } - else if (has_minicpmv_projector == 4) { - user_prompt = "<|im_start|>user\n" + prompt; - } - } - - eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); - if (has_minicpmv_projector == 2) { - eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false); - } - else if (has_minicpmv_projector == 3) { - eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); - } - else if (has_minicpmv_projector == 4) { - eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false); - } - - // generate the response - - LOG_INF("\n"); - - struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); - return smpl; -} - -static const char * llama_loop(struct llava_context * ctx_llava,struct common_sampler * smpl, int &n_past){ - - const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); - return tmp; -} - -int main(int argc, char ** argv) { - ggml_time_init(); - - common_params params; - - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) { - return 1; - } - - common_init(); - - if (params.mmproj.path.empty() || (params.image.empty())) { - show_additional_info(argc, argv); - return 1; - } - - for (auto & image : params.image) { - int n_past = 0; - auto * ctx_llava = minicpmv_init(¶ms, image, n_past); - - if (!params.prompt.empty()) { - LOG("%s\n", params.prompt.c_str()); - LOG(""); - auto * smpl = llama_init(ctx_llava, ¶ms, params.prompt, n_past, true); - const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; - std::string response; - bool have_tmp = false; - for (int i = 0; i < max_tgt_len; i++) { - const auto * tmp = llama_loop(ctx_llava, smpl, n_past); - response += tmp; - if (strcmp(tmp, "") == 0){ - if (!have_tmp) { - continue; - } - break; - } - if (strstr(tmp, "###")) break; // Yi-VL behavior - have_tmp = true; - printf("%s", tmp); - if (strstr(response.c_str(), "")) break; // minicpm-v - - fflush(stdout); - } - common_sampler_free(smpl); - }else { - while (true) { - LOG(""); - std::string prompt; - std::getline(std::cin, prompt); - LOG(""); - auto * smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true); - const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; - std::string response; - for (int i = 0; i < max_tgt_len; i++) { - const auto * tmp = llama_loop(ctx_llava, smpl, n_past); - response += tmp; - if (strcmp(tmp, "") == 0) break; - printf("%s", tmp);// mistral llava-1.6 - if (strstr(response.c_str(), "")) break; // minicpm-v - fflush(stdout); - } - common_sampler_free(smpl); - } - } - printf("\n"); - llama_perf_context_print(ctx_llava->ctx_llama); - - ctx_llava->model = NULL; - llava_free(ctx_llava); - } - - return 0; -} diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp deleted file mode 100644 index 114c274bc..000000000 --- a/examples/llava/mtmd.cpp +++ /dev/null @@ -1,341 +0,0 @@ -#include "clip.h" -#include "clip-impl.h" -#include "mtmd.h" - -#include "llama.h" - -#include -#include -#include -#include -#include -#include -#include - -struct mtmd_context { - struct clip_ctx * ctx_clip; - const struct llama_model * text_model; - std::vector image_embd_v; // image embedding vector - bool print_timings; - int n_threads; - std::string image_marker; - - // TODO @ngxson : add timings - - mtmd_context(const char * mmproj_fname, - const llama_model * text_model, - const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) { - clip_context_params ctx_clip_params; - ctx_clip_params.use_gpu = ctx_params.use_gpu; - ctx_clip_params.verbosity = ctx_params.verbosity; - ctx_clip = clip_init(mmproj_fname, ctx_clip_params); - if (!ctx_clip) { - throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname)); - } - this->text_model = text_model; - } - - ~mtmd_context() { - clip_free(ctx_clip); - } -}; - -struct mtmd_image_tokens_data { - clip_image_f32_batch batch_f32; // preprocessed image patches -}; - -struct mtmd_image_tokens { - uint32_t nx; // number of tokens in x direction - uint32_t ny; // number of tokens in y direction - uint32_t n_tokens() const { return nx * ny; } - clip_image_f32_batch batch_f32; // preprocessed image patches -}; - -mtmd_context * mtmd_init_from_file(const char * mmproj_fname, - const struct llama_model * text_model, - const struct mtmd_context_params ctx_params) { - try { - return new mtmd_context(mmproj_fname, text_model, ctx_params); - } catch (const std::exception & e) { - LOG_ERR("%s: error: %s\n", __func__, e.what()); - return nullptr; - } -} - -void mtmd_free(mtmd_context * ctx) { - if (ctx) { - delete ctx; - } -} - -// copied from common_tokenize -static std::vector mtmd_tokenize_text_internal( - const struct llama_vocab * vocab, - const std::string & text, - bool add_special, - bool parse_special) { - // upper limit for the number of tokens - int n_tokens = text.length() + 2 * add_special; - std::vector result(n_tokens); - n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); - if (n_tokens < 0) { - result.resize(-n_tokens); - int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); - GGML_ASSERT(check == -n_tokens); - } else { - result.resize(n_tokens); - } - return result; -} - -mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, - const mtmd_input_text & text, - const std::vector & bitmaps) { - mtmd_input_chunks * output = new mtmd_input_chunks; - auto vocab = llama_model_get_vocab(ctx->text_model); - - std::string prompt_modified(text.text); - std::string marker_modified(ctx->image_marker); - projector_type proj_type = clip_get_projector_type(ctx->ctx_clip); - // a bit hacky here, but works for now - // for some models, we need to add prefix and suffix to the image embeddings - if (proj_type == PROJECTOR_TYPE_GEMMA3) { - // ... (image embeddings) ... - marker_modified = "" + ctx->image_marker + ""; - string_replace_all(prompt_modified, ctx->image_marker, marker_modified); - } - - std::vector parts = string_split_str(text.text, ctx->image_marker); - output->clear(); - output->reserve(parts.size()); - - size_t i_img = 0; - - for (const auto & part : parts) { - //printf("tokenizing part: %s\n", part.c_str()); - bool add_bos = &parts.front() == ∂ - auto tokens = mtmd_tokenize_text_internal(vocab, part, text.add_special && add_bos, text.parse_special); - if (tokens.empty()) { - continue; - } - mtmd_input_chunk chunk{ - MTMD_INPUT_CHUNK_TYPE_TEXT, - std::move(tokens), - {}, - }; - output->emplace_back(std::move(chunk)); - - if (&parts.back() != &part) { - // add image token to middle of 2 parts - - if (i_img >= bitmaps.size()) { - LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size()); - return nullptr; - } - - // shim layer - clip_image_u8_ptr img_u8(clip_image_u8_init()); - img_u8->nx = bitmaps[i_img].nx; - img_u8->ny = bitmaps[i_img].ny; - img_u8->buf.resize(bitmaps[i_img].data.size()); - std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3); - - // preprocess image - clip_image_f32_batch batch_f32; - bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32); - if (!ok) { - LOG_ERR("Unable to preprocess image\n"); - return nullptr; - } - - mtmd_image_tokens * image_tokens = new mtmd_image_tokens; - image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image - image_tokens->ny = 1; // TODO - image_tokens->batch_f32 = std::move(batch_f32); - - mtmd_input_chunk chunk{ - MTMD_INPUT_CHUNK_TYPE_IMAGE, - {}, - image_tokens, - }; - output->emplace_back(std::move(chunk)); - i_img++; - } - } - - return output; -} - -void mtmd_input_chunks_free(mtmd_input_chunks * chunks) { - for (auto & chunk : *chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) { - delete chunk.tokens_image; - } - } - delete chunks; -} - -int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) { - int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip); - ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); - bool ok = clip_image_batch_encode( - ctx->ctx_clip, - ctx->n_threads, - &image_tokens->batch_f32, - ctx->image_embd_v.data()); - return ok ? 0 : 1; -} - -float * mtmd_get_output_embd(mtmd_context * ctx) { - return ctx->image_embd_v.data(); -} - -size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) { - size_t n_tokens = 0; - for (auto & chunk : *chunks) { - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - n_tokens += chunk.tokens_text.size(); - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - n_tokens += chunk.tokens_image->n_tokens(); - } else { - GGML_ASSERT(false && "chunk type not supported"); - } - } - return n_tokens; -} - -// helper struct to make working with embd batch easier -// note: this will be removed after llama_batch_ext refactoring -struct decode_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - -int32_t mtmd_helper_eval(mtmd_context * ctx, - llama_context * lctx, - mtmd_input_chunks * chunks, - llama_pos pos0, - llama_seq_id seq_id, - int32_t n_batch) { - int32_t ret; - llama_pos n_past = pos0; - llama_batch text_batch = llama_batch_init(n_batch, 0, 1); - - for (auto & chunk : *chunks) { - bool is_last = &chunk == &chunks->back(); - if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - // TODO @ngxson : may need to split into smaller batches - text_batch.n_tokens = chunk.tokens_text.size(); - for (size_t i = 0; i < chunk.tokens_text.size(); i++) { - text_batch.token [i] = chunk.tokens_text[i]; - text_batch.pos [i] = n_past++; - text_batch.n_seq_id[i] = 1; - text_batch.seq_id [i][0] = seq_id; - text_batch.logits [i] = false; - } - if (is_last) { - // always get logits for last input chunk - text_batch.logits[text_batch.n_tokens - 1] = true; - } - ret = llama_decode(lctx, text_batch); - if (ret != 0) { - LOG_ERR("failed to decode text\n"); - llama_batch_free(text_batch); - return ret; - } - - } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) { - GGML_ASSERT(!is_last && "logits for last image chunk is not yet support"); - GGML_ASSERT(chunk.tokens_image != nullptr); - int64_t t0 = ggml_time_ms(); - if (ctx->print_timings) { - LOG_INF("encoding image...\n"); - } - ret = mtmd_encode(ctx, chunk.tokens_image); - if (ret != 0) { - LOG_ERR("failed to encode image\n"); - llama_batch_free(text_batch); - return ret; - } - if (ctx->print_timings) { - LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0); - } - - int32_t n_tokens = chunk.tokens_image->n_tokens(); - float * embd = mtmd_get_output_embd(ctx); - decode_embd_batch batch_img(embd, n_tokens, n_past, 0); - int64_t t1 = ggml_time_ms(); - ret = llama_decode(lctx, batch_img.batch); - if (ret != 0) { - LOG_ERR("failed to decode image\n"); - llama_batch_free(text_batch); - return ret; - } - if (ctx->print_timings) { - LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1); - } - - n_past += n_tokens; - - } else { - GGML_ASSERT(false && "chunk type not supported"); - } - } - - llama_batch_free(text_batch); - return 0; -} - -int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output) { - clip_image_u8_ptr img_u8(clip_image_u8_init()); - bool ok = clip_image_load_from_bytes(buf, len, img_u8.get()); - if (!ok) { - LOG_ERR("Unable to load image from buffer\n"); - return 1; - } - unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); - output.data.resize(output.nx * output.ny * 3); - std::memcpy(output.data.data(), data, output.nx * output.ny * 3); - return 0; -} - -int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output) { - clip_image_u8_ptr img_u8(clip_image_u8_init()); - bool ok = clip_image_load_from_file(fname, img_u8.get()); - if (!ok) { - LOG_ERR("Unable to load image %s\n", fname); - return 1; - } - unsigned char * data = clip_image_u8_get_data(img_u8.get(), &output.nx, &output.ny); - output.data.resize(output.nx * output.ny * 3); - std::memcpy(output.data.data(), data, output.nx * output.ny * 3); - return 0; -} diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h deleted file mode 100644 index 598f6947b..000000000 --- a/examples/llava/mtmd.h +++ /dev/null @@ -1,146 +0,0 @@ -#ifndef MTMD_H -#define MTMD_H - -#include "ggml.h" -#include "llama.h" -#include "clip.h" - -#include -#include -#include - -#ifdef LLAMA_SHARED -# if defined(_WIN32) && !defined(__MINGW32__) -# ifdef LLAMA_BUILD -# define MTMD_API __declspec(dllexport) -# else -# define MTMD_API __declspec(dllimport) -# endif -# else -# define MTMD_API __attribute__ ((visibility ("default"))) -# endif -#else -# define MTMD_API -#endif - -#ifdef __cplusplus - -enum mtmd_input_chunk_type { - MTMD_INPUT_CHUNK_TYPE_TEXT, - MTMD_INPUT_CHUNK_TYPE_IMAGE, -}; - -struct mtmd_context; -struct mtmd_image_tokens; - -// represents raw image data, layout is RGBRGBRGB... -// length of data must be nx * ny * 3 -struct mtmd_bitmap { - uint32_t nx; - uint32_t ny; - std::vector data; -}; - -struct mtmd_input_chunk { - mtmd_input_chunk_type type; - std::vector tokens_text; - mtmd_image_tokens * tokens_image = nullptr; -}; - -using mtmd_input_chunks = std::vector; - -struct mtmd_context_params { - bool use_gpu = true; - bool print_timings = true; - int n_threads = 4; - enum ggml_log_level verbosity = GGML_LOG_LEVEL_INFO; - const char * image_marker = "<__image__>"; -}; - -struct mtmd_input_text { - std::string text; - bool add_special; - bool parse_special; -}; - -// initialize the mtmd context -// return nullptr on failure -MTMD_API mtmd_context * mtmd_init_from_file(const char * mmproj_fname, - const llama_model * text_model, - const mtmd_context_params ctx_params); - -MTMD_API void mtmd_free(mtmd_context * ctx); - -// tokenize an input text prompt and an image -// the prompt must have the input image marker (default: "<__image__>") in it -// the marker will be replaced with the image tokens -// for example: -// "here is an image: <__image__>\ndescribe it in detail." -// this will gives 3 chunks: -// 1. "here is an image: " -// 2. (image tokens) -// 3. "\ndescribe it in detail." -// number of bitmaps must be equal to the number of image markers in the prompt -// this function is thread-safe (shared ctx) -MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx, - const mtmd_input_text & text, - const std::vector & bitmaps); - -// free image chunk data -MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks); - -// returns 0 on success -MTMD_API int32_t mtmd_encode(mtmd_context * ctx, - const mtmd_image_tokens * image_tokens); - -// get output embeddings from the last encode pass -MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx); - -// -// helper functions (can be implemented based on other functions) -// - -// helper to count the total number of tokens from a list of chunks, useful to keep track of n_past -MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks); - -// helper function that automatically: -// 1. run llama_decode() on text chunks -// 2. run mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode() -// if any of the mtmd_encode() or llama_decode() calls return non-zero, stop and forward the error -// otherwise, returns 0 on success -MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx, - llama_context * lctx, - mtmd_input_chunks * chunks, - llama_pos pos0, - llama_seq_id seq_id, - int32_t n_batch); - -// helper function to construct a mtmd_bitmap from a file -// returns 0 on success -// this function is thread-safe -MTMD_API int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & output); - -// helper function to construct a mtmd_bitmap from a buffer -// the buffer must be an image in format supported by stb_image (jpg, png, bmp, gif, etc.) -// returns 0 on success -// this function is thread-safe -MTMD_API int32_t mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len, mtmd_bitmap & output); - -// convenient unique_ptr wrappers -struct mtmd_context_deleter { - void operator()(mtmd_context * val) { mtmd_free(val); } -}; -using mtmd_context_ptr = std::unique_ptr; - -struct mtmd_input_chunks_deleter { - void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); } -}; -using mtmd_input_chunks_ptr = std::unique_ptr; - -#else - -static_assert(false && "C header is not yet supported by this library"); - -#endif - -#endif diff --git a/examples/llava/qwen2_vl_surgery.py b/examples/llava/qwen2_vl_surgery.py deleted file mode 100644 index c87606b4f..000000000 --- a/examples/llava/qwen2_vl_surgery.py +++ /dev/null @@ -1,165 +0,0 @@ -import argparse -from typing import Dict - -import torch -import numpy as np -from gguf import * -from transformers import ( - Qwen2VLForConditionalGeneration, - Qwen2VLProcessor, - AutoProcessor, - Qwen2VLConfig -) - - -VISION = "clip.vision" - - -def k(raw_key: str, arch: str) -> str: - return raw_key.format(arch=arch) - - -def to_gguf_name(name: str) -> str: - og = name - name = name.replace("text_model", "t").replace("vision_model", "v") - name = name.replace("blocks", "blk").replace("embeddings.", "") - name = name.replace("attn.", "attn_") - name = name.replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("proj.", "out.") - # name = name.replace("layrnorm", "ln").replace("layer_norm", "ln").replace("layernorm", "ln") - name = name.replace("norm1", "ln1").replace("norm2", "ln2") - name = name.replace("merger.mlp", 'mm') - print(f"[to_gguf_name] {og} --> {name}") - return name - - -def find_vision_tensors(qwen2vl, dtype) -> Dict[str, np.ndarray]: - vision_model = qwen2vl.visual - tensor_map = {} - for name, ten in vision_model.state_dict().items(): - ten = ten.numpy() - if 'qkv' in name: - if ten.ndim == 2: # weight - c3, _ = ten.shape - else: # bias - c3 = ten.shape[0] - assert c3 % 3 == 0 - c = c3 // 3 - wq = ten[:c] - wk = ten[c: c * 2] - wv = ten[c * 2:] - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "q")] = wq - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "k")] = wk - tensor_map[to_gguf_name(f"vision_model.{name}").replace("qkv", "v")] = wv - elif 'merger' in name: - if name.endswith("ln_q.weight"): - tensor_map['v.post_ln.weight'] = ten - elif name.endswith("ln_q.bias"): - tensor_map['v.post_ln.bias'] = ten - else: - # "merger.mlp.%d.weight/bias" --> "mm.%d.weight/bias" - tensor_map[to_gguf_name(name)] = ten - elif 'patch_embed.proj.weight' in name: - # NOTE: split Conv3D into Conv2Ds - c1, c2, kt, kh, kw = ten.shape - assert kt == 2, "Current implmentation only support temporal_patch_size of 2" - tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...] - tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...] - else: - tensor_map[to_gguf_name(f"vision_model.{name}")] = ten - - for new_name, ten in tensor_map.items(): - if ten.ndim <= 1 or new_name.endswith("_norm.weight"): - tensor_map[new_name] = ten.astype(np.float32) - else: - tensor_map[new_name] = ten.astype(dtype) - tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32) # dummy tensor, just here as a placeholder - return tensor_map - - -def main(args): - if args.data_type == 'fp32': - dtype = torch.float32 - np_dtype = np.float32 - ftype = 0 - elif args.data_type == 'fp16': - dtype = torch.float32 - np_dtype = np.float16 - ftype = 1 - else: - raise ValueError() - - local_model = False - model_path = "" - model_name = args.model_name - print("model_name: ", model_name) - qwen2vl = Qwen2VLForConditionalGeneration.from_pretrained( - model_name, torch_dtype=dtype, device_map="cpu" - ) - cfg: Qwen2VLConfig = qwen2vl.config # type: ignore[reportAssignmentType] - vcfg = cfg.vision_config - - if os.path.isdir(model_name): - local_model = True - if model_name.endswith(os.sep): - model_name = model_name[:-1] - model_path = model_name - model_name = os.path.basename(model_name) - fname_out = f"{model_name.replace('/', '-').lower()}-vision.gguf" - - fout = GGUFWriter(path=fname_out, arch="clip") - fout.add_description("image encoder for Qwen2VL") - - fout.add_file_type(ftype) - fout.add_bool("clip.has_text_encoder", False) - fout.add_bool("clip.has_vision_encoder", True) - fout.add_bool("clip.has_qwen2vl_merger", True) - fout.add_string("clip.projector_type", "qwen2vl_merger") - - print(cfg.vision_config) - if 'silu' in cfg.vision_config.hidden_act.lower(): - fout.add_bool("clip.use_silu", True) - fout.add_bool("clip.use_gelu", False) - elif 'gelu' in cfg.vision_config.hidden_act.lower(): - fout.add_bool("clip.use_silu", False) - fout.add_bool("clip.use_gelu", 'quick' not in cfg.vision_config.hidden_act.lower()) - else: - raise ValueError() - - tensor_map = find_vision_tensors(qwen2vl, np_dtype) - for name, data in tensor_map.items(): - fout.add_tensor(name, data) - - fout.add_uint32("clip.vision.patch_size", vcfg.patch_size) - fout.add_uint32("clip.vision.image_size", 14 * 40) # some reasonable size that is divable by (14*2) - fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), vcfg.embed_dim) - fout.add_uint32("clip.vision.projection_dim", vcfg.hidden_size) - fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads) - fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) - fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth) - fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0) # not sure what this does, put 0 here as a placeholder - fout.add_name(model_name) - """ - HACK: Since vision rope related parameter aren't stored in the `Qwen2VLConfig, - it will be hardcoded in the `clip_image_build_graph` from `clip.cpp`. - """ - - if local_model: - processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_path) - else: - processor: Qwen2VLProcessor = AutoProcessor.from_pretrained(model_name) - fout.add_array("clip.vision.image_mean", processor.image_processor.image_mean) # type: ignore[reportAttributeAccessIssue] - fout.add_array("clip.vision.image_std", processor.image_processor.image_std) # type: ignore[reportAttributeAccessIssue] - - fout.write_header_to_file() - fout.write_kv_data_to_file() - fout.write_tensors_to_file() - fout.close() - print("save model as: ", fname_out) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("model_name", nargs='?', default="Qwen/Qwen2-VL-2B-Instruct") - parser.add_argument("--data_type", nargs='?', choices=['fp32', 'fp16'], default="fp32") - args = parser.parse_args() - main(args) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp deleted file mode 100644 index eca7b7f10..000000000 --- a/examples/llava/qwen2vl-cli.cpp +++ /dev/null @@ -1,584 +0,0 @@ -#include "arg.h" -#include "base64.hpp" -#include "log.h" -#include "common.h" -#include "sampling.h" -#include "clip.h" -#include "llava.h" -#include "llama.h" -#include "ggml.h" - -#ifdef GGML_USE_CUDA -#include "ggml-cuda.h" -#endif -#ifdef NDEBUG -#include "ggml-alloc.h" -#include "ggml-backend.h" -#endif - -#include -#include -#include -#include -#include -#include -#include - - -static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, - int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) { - int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); - const int patch_size = 14 * 2; - const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0); - const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0); - auto img_tokens = image_embed->n_image_pos; - // llama_pos mrope_pos[img_tokens * 4]; - std::vector mrope_pos; - mrope_pos.resize(img_tokens * 4); - - for (int y = 0; y < ph; y++) - { - for (int x = 0; x < pw; x++) - { - int i = y * pw + x; - mrope_pos[i] = *st_pos_id; - mrope_pos[i + img_tokens] = *st_pos_id + y; - mrope_pos[i + img_tokens * 2] = *st_pos_id + x; - mrope_pos[i + img_tokens * 3] = 0; - } - } - *st_pos_id += std::max(pw, ph); - - int processed = 0; - std::vector batch_mrope_pos; - batch_mrope_pos.resize(img_tokens * 4); - - for (int i = 0; i < img_tokens; i += n_batch) { - int n_eval = img_tokens - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - - // llama_pos batch_mrope_pos[n_eval * 4]; - std::fill(batch_mrope_pos.begin(), batch_mrope_pos.end(), 0); - memcpy(batch_mrope_pos.data(), &mrope_pos[processed], n_eval * sizeof(llama_pos)); - memcpy(&batch_mrope_pos[n_eval * 1], &mrope_pos[img_tokens * 1 + processed], n_eval * sizeof(llama_pos)); - memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); - memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); - - llama_batch batch = { - int32_t(n_eval), // n_tokens - nullptr, // token - (image_embed->embed+i*n_embd), // embed - batch_mrope_pos.data(), // pos - nullptr, // n_seq_id - nullptr, // seq_id - nullptr, // logits - }; - - if (llama_decode(ctx_llama, batch)) { - LOG_ERR("%s : failed to eval\n", __func__); - return false; - } - *n_past += n_eval; - processed += n_eval; - } - return true; -} - - -static bool eval_tokens(struct llama_context * ctx_llama, std::vector tokens, int n_batch, int * n_past, int * st_pos_id) { - int N = (int) tokens.size(); - std::vector pos; - for (int i = 0; i < N; i += n_batch) { - int n_eval = (int) tokens.size() - i; - if (n_eval > n_batch) { - n_eval = n_batch; - } - auto batch = llama_batch_get_one(&tokens[i], n_eval); - // TODO: add mrope pos ids somewhere else - pos.resize(batch.n_tokens * 4); - std::fill(pos.begin(), pos.end(), 0); - for (int j = 0; j < batch.n_tokens * 3; j ++) { - pos[j] = *st_pos_id + (j % batch.n_tokens); - } - batch.pos = pos.data(); - - if (llama_decode(ctx_llama, batch)) { - LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); - return false; - } - *n_past += n_eval; - *st_pos_id += n_eval; - } - return true; -} - -static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past, int * st_pos_id) { - std::vector tokens; - tokens.push_back(id); - return eval_tokens(ctx_llama, tokens, 1, n_past, st_pos_id); -} - -static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, int * st_pos_id, bool add_bos){ - std::string str2 = str; - std::vector embd_inp = common_tokenize(ctx_llama, str2, add_bos, true); - eval_tokens(ctx_llama, embd_inp, n_batch, n_past, st_pos_id); - return true; -} - -static const char * sample(struct common_sampler * smpl, - struct llama_context * ctx_llama, - int * n_past, int * st_pos_id) { - const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); - common_sampler_accept(smpl, id, true); - - const llama_model * model = llama_get_model(ctx_llama); - const llama_vocab * vocab = llama_model_get_vocab(model); - - static std::string ret; - if (llama_vocab_is_eog(vocab, id)) { - ret = ""; - } else { - ret = common_token_to_piece(ctx_llama, id); - } - eval_id(ctx_llama, id, n_past, st_pos_id); - return ret.c_str(); -} - -static const char* IMG_BASE64_TAG_BEGIN = ""; - -static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) { - begin_out = prompt.find(IMG_BASE64_TAG_BEGIN); - end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out); -} - -static bool prompt_contains_image(const std::string& prompt) { - size_t begin, end; - find_image_tag_in_prompt(prompt, begin, end); - return (begin != std::string::npos); -} - -// replaces the base64 image tag in the prompt with `replacement` -static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) { - size_t img_base64_str_start, img_base64_str_end; - find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end); - if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) { - LOG_ERR("%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END); - return NULL; - } - - auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN); - auto base64_bytes_count = img_base64_str_end - base64_bytes_start; - auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count ); - - auto required_bytes = base64::required_encode_size(base64_str.size()); - auto img_bytes = std::vector(required_bytes); - base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin()); - - auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size()); - if (!embed) { - LOG_ERR("%s: could not load image from base64 string.\n", __func__); - return NULL; - } - - return embed; -} - -static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") { - size_t begin, end; - find_image_tag_in_prompt(prompt, begin, end); - if (begin == std::string::npos || end == std::string::npos) { - return prompt; - } - auto pre = prompt.substr(0, begin); - auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END)); - return pre + replacement + post; -} - -struct llava_context { - struct clip_ctx * ctx_clip = NULL; - struct llama_context * ctx_llama = NULL; - struct llama_model * model = NULL; -}; - -static void print_usage(int, char ** argv) { - LOG("\n example usage:\n"); - LOG("\n %s -m --mmproj --image --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); - LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n"); -} - -static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) { - - // load and preprocess the image - llava_image_embed * embed = NULL; - auto prompt = params->prompt; - if (prompt_contains_image(prompt)) { - if (!params->image.empty()) { - LOG_INF("using base64 encoded image instead of command line image path\n"); - } - embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt); - if (!embed) { - LOG_ERR("%s: can't load image from prompt\n", __func__); - return NULL; - } - params->prompt = remove_image_from_prompt(prompt); - } else { - embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str()); - if (!embed) { - fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str()); - return NULL; - } - } - - return embed; -} - -static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) { - int n_past = 0; - int cur_pos_id = 0; - - const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; - - std::string system_prompt, user_prompt; - size_t image_pos = prompt.find("<|vision_start|>"); - if (image_pos != std::string::npos) { - // new templating mode: Provide the full prompt including system message and use as a placeholder for the image - system_prompt = prompt.substr(0, image_pos); - user_prompt = prompt.substr(image_pos + std::string("<|vision_pad|>").length()); - LOG_INF("system_prompt: %s\n", system_prompt.c_str()); - if (params->verbose_prompt) { - auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { - LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); - } - } - LOG_INF("user_prompt: %s\n", user_prompt.c_str()); - if (params->verbose_prompt) { - auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { - LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); - } - } - } else { - // llava-1.5 native mode - system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|>"; - user_prompt = "<|vision_end|>" + prompt + "<|im_end|>\n<|im_start|>assistant\n"; - if (params->verbose_prompt) { - auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true); - for (int i = 0; i < (int) tmp.size(); i++) { - LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str()); - } - } - } - - eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, true); - if (image_embed != nullptr) { - auto image_size = clip_get_load_image_size(ctx_llava->ctx_clip); - qwen2vl_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past, &cur_pos_id, image_size); - } - eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, &cur_pos_id, false); - - // generate the response - - LOG("\n"); - - struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling); - if (!smpl) { - LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); - exit(1); - } - - std::string response = ""; - for (int i = 0; i < max_tgt_len; i++) { - const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past, &cur_pos_id); - response += tmp; - if (strcmp(tmp, "") == 0) break; - if (strstr(tmp, "###")) break; // Yi-VL behavior - LOG("%s", tmp); - if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works) - if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6 - if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6 - - fflush(stdout); - } - - common_sampler_free(smpl); - LOG("\n"); -} - -static struct llama_model * llava_init(common_params * params) { - llama_backend_init(); - llama_numa_init(params->numa); - - llama_model_params model_params = common_model_params_to_llama(*params); - - llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params); - if (model == NULL) { - LOG_ERR("%s: unable to load model\n" , __func__); - return NULL; - } - return model; -} - -static struct llava_context * llava_init_context(common_params * params, llama_model * model) { - const char * clip_path = params->mmproj.path.c_str(); - - auto prompt = params->prompt; - if (prompt.empty()) { - prompt = "describe the image in detail."; - } - - auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO); - - llama_context_params ctx_params = common_context_params_to_llama(*params); - ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - - llama_context * ctx_llama = llama_init_from_model(model, ctx_params); - - if (ctx_llama == NULL) { - LOG_ERR("%s: failed to create the llama_context\n" , __func__); - return NULL; - } - - auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context)); - - ctx_llava->ctx_llama = ctx_llama; - ctx_llava->ctx_clip = ctx_clip; - ctx_llava->model = model; - return ctx_llava; -} - -static void llava_free(struct llava_context * ctx_llava) { - if (ctx_llava->ctx_clip) { - clip_free(ctx_llava->ctx_clip); - ctx_llava->ctx_clip = NULL; - } - - llama_free(ctx_llava->ctx_llama); - llama_model_free(ctx_llava->model); - llama_backend_free(); -} - -#ifndef NDEBUG - -static void debug_test_mrope_2d() { - // 1. Initialize backend - ggml_backend_t backend = NULL; - std::string backend_name = ""; -#ifdef GGML_USE_CUDA - fprintf(stderr, "%s: using CUDA backend\n", __func__); - backend = ggml_backend_cuda_init(0); // init device 0 - backend_name = "cuda"; - if (!backend) { - fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); - } -#endif - // if there aren't GPU Backends fallback to CPU backend - if (!backend) { - backend = ggml_backend_cpu_init(); - backend_name = "cpu"; - } - - // Calculate the size needed to allocate - size_t ctx_size = 0; - ctx_size += 2 * ggml_tensor_overhead(); // tensors - // no need to allocate anything else! - - // 2. Allocate `ggml_context` to store tensor data - struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_backend_alloc_ctx_tensors() - }; - struct ggml_context * ctx = ggml_init(params); - - struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 12, 30); - ggml_set_name(inp_raw, "inp_raw"); - ggml_set_input(inp_raw); - - struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 30 * 4); - ggml_set_name(pos, "pos"); - ggml_set_input(pos); - - std::vector dummy_q; - dummy_q.resize(128 * 12 * 30); - std::fill(dummy_q.begin(), dummy_q.end(), 0.1); - // memcpy(inp_raw->data, dummy_q.data(), 128 * 12 * 30 * ggml_element_size(inp_raw)); - - std::vector pos_id; - pos_id.resize(30 * 4); - for (int i = 0; i < 30; i ++) { - pos_id[i] = i; - pos_id[i + 30] = i + 10; - pos_id[i + 60] = i + 20; - pos_id[i + 90] = i + 30; - } - int sections[4] = {32, 32, 0, 0}; - - // 4. Allocate a `ggml_backend_buffer` to store all tensors - ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); - - // 5. Copy tensor data from main memory (RAM) to backend buffer - ggml_backend_tensor_set(inp_raw, dummy_q.data(), 0, ggml_nbytes(inp_raw)); - ggml_backend_tensor_set(pos, pos_id.data(), 0, ggml_nbytes(pos)); - - // 6. Create a `ggml_cgraph` for mul_mat operation - struct ggml_cgraph * gf = NULL; - struct ggml_context * ctx_cgraph = NULL; - - // create a temporally context to build the graph - struct ggml_init_params params0 = { - /*.mem_size =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() - }; - ctx_cgraph = ggml_init(params0); - gf = ggml_new_graph(ctx_cgraph); - - struct ggml_tensor * result0 = ggml_rope_multi( - ctx_cgraph, inp_raw, pos, nullptr, - 128/2, sections, LLAMA_ROPE_TYPE_VISION, 32768, 1000000, 1, - 0, 1, 32, 1); - - // Add "result" tensor and all of its dependencies to the cgraph - ggml_build_forward_expand(gf, result0); - - // 7. Create a `ggml_gallocr` for cgraph computation - ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); - ggml_gallocr_alloc_graph(allocr, gf); - - // 9. Run the computation - int n_threads = 1; // Optional: number of threads to perform some operations with multi-threading - if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, n_threads); - } - ggml_backend_graph_compute(backend, gf); - - // 10. Retrieve results (output tensors) - // in this example, output tensor is always the last tensor in the graph - struct ggml_tensor * result = result0; - // struct ggml_tensor * result = gf->nodes[gf->n_nodes - 1]; - float * result_data = (float *)malloc(ggml_nbytes(result)); - // because the tensor data is stored in device buffer, we need to copy it back to RAM - ggml_backend_tensor_get(result, result_data, 0, ggml_nbytes(result)); - const std::string bin_file = "mrope_2d_" + backend_name +".bin"; - std::ofstream outFile(bin_file, std::ios::binary); - - if (outFile.is_open()) { - outFile.write(reinterpret_cast(result_data), ggml_nbytes(result)); - outFile.close(); - std::cout << "Data successfully written to " + bin_file << std::endl; - } else { - std::cerr << "Error opening file!" << std::endl; - } - - free(result_data); - // 11. Free memory and exit - ggml_free(ctx_cgraph); - ggml_gallocr_free(allocr); - ggml_free(ctx); - ggml_backend_buffer_free(buffer); - ggml_backend_free(backend); -} - -static void debug_dump_img_embed(struct llava_context * ctx_llava) { - int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); - int ne = n_embd * 4; - float vals[56 * 56 * 3]; - // float embd[ne]; - std::vector embd; - embd.resize(ne); - - for (int i = 0; i < 56*56; i++) - { - for (int c = 0; c < 3; c++) - vals[i * 3 + c] = (float)(i % (56 * 56)) / (56*56); - } - - clip_encode_float_image(ctx_llava->ctx_clip, 16, vals, 56, 56, embd.data()); - - std::ofstream outFile("img_embed.bin", std::ios::binary); - if (outFile.is_open()) { - outFile.write(reinterpret_cast(embd.data()), ne * sizeof(float)); - - outFile.close(); - std::cout << "Data successfully written to mrope.bin" << std::endl; - } else { - std::cerr << "Error opening file!" << std::endl; - } -} - -#endif - - -int main(int argc, char ** argv) { - ggml_time_init(); - - common_params params; - - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) { - return 1; - } - - common_init(); - - if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { - print_usage(argc, argv); - return 1; - } - - auto * model = llava_init(¶ms); - if (model == NULL) { - fprintf(stderr, "%s: error: failed to init llava model\n", __func__); - return 1; - } - - if (prompt_contains_image(params.prompt)) { - auto * ctx_llava = llava_init_context(¶ms, model); - - auto * image_embed = load_image(ctx_llava, ¶ms, ""); - - // process the prompt - process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - - llama_perf_context_print(ctx_llava->ctx_llama); - llava_image_embed_free(image_embed); - ctx_llava->model = NULL; - llava_free(ctx_llava); -#ifndef NDEBUG - } else if (params.image[0].empty()) { - auto ctx_llava = llava_init_context(¶ms, model); - - debug_test_mrope_2d(); - debug_dump_img_embed(ctx_llava); - - llama_perf_context_print(ctx_llava->ctx_llama); - ctx_llava->model = NULL; - llava_free(ctx_llava); -#endif - } else { - for (auto & image : params.image) { - auto * ctx_llava = llava_init_context(¶ms, model); - - auto * image_embed = load_image(ctx_llava, ¶ms, image); - if (!image_embed) { - LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str()); - return 1; - } - - // process the prompt - process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - - llama_perf_context_print(ctx_llava->ctx_llama); - llava_image_embed_free(image_embed); - ctx_llava->model = NULL; - llava_free(ctx_llava); - } - } - - llama_model_free(model); - - return 0; -} diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh deleted file mode 100755 index cc9bda876..000000000 --- a/examples/llava/tests.sh +++ /dev/null @@ -1,81 +0,0 @@ -#!/bin/bash - -# make sure we are in the right directory -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cd $SCRIPT_DIR - -#export LLAMA_CACHE="$SCRIPT_DIR/tmp" - -set -eux - -mkdir -p $SCRIPT_DIR/output - -PROJ_ROOT="$SCRIPT_DIR/../.." -cd $PROJ_ROOT - -############### - -arr_bin=() -arr_hf=() - -add_test() { - local bin=$1 - local hf=$2 - arr_bin+=("$bin") - arr_hf+=("$hf") -} - -add_test "llama-gemma3-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" -add_test "llama-llava-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K" -add_test "llama-llava-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" -add_test "llama-llava-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M" -add_test "llama-llava-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" -add_test "llama-llava-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" -add_test "llama-llava-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M" -add_test "llama-minicpmv-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted -add_test "llama-minicpmv-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K" -add_test "llama-minicpmv-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0" -add_test "llama-qwen2vl-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" - -############### - -cmake --build build -j --target "${arr_bin[@]}" - -arr_res=() - -for i in "${!arr_bin[@]}"; do - bin="${arr_bin[$i]}" - hf="${arr_hf[$i]}" - - echo "Running test with binary: $bin and HF model: $hf" - echo "" - echo "" - - output=$("$PROJ_ROOT/build/bin/$bin" -hf "$hf" --image $SCRIPT_DIR/test-1.jpeg -p "what is the publisher name of the newspaper?" --temp 0 2>&1 | tee /dev/tty) - - echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log - - if echo "$output" | grep -iq "new york"; then - result="\033[32mOK\033[0m: $bin $hf" - else - result="\033[31mFAIL\033[0m: $bin $hf" - fi - echo -e "$result" - arr_res+=("$result") - - echo "" - echo "" - echo "" - echo "#################################################" - echo "#################################################" - echo "" - echo "" -done - -set +x - -for i in "${!arr_res[@]}"; do - echo -e "${arr_res[$i]}" -done -echo "" -echo "Output logs are saved in $SCRIPT_DIR/output" diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 7df20aee1..1e26d8221 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -50,8 +50,6 @@ int main(int argc, char ** argv) { const int N = 5; // n-gram size const int G = 15; // max verification n-grams - const bool dump_kv_cache = params.dump_kv_cache; - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -62,6 +60,8 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + auto * mem = llama_get_memory(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); // Tokenize the prompt @@ -96,7 +96,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_self_seq_cp(ctx, 0, s, -1, -1); + llama_memory_seq_cp(mem, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -152,9 +152,6 @@ int main(int argc, char ** argv) { // here we keep adding new n-grams as we go ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G); - // debug - struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); - const auto t_dec_start = ggml_time_us(); // sample first token @@ -172,12 +169,6 @@ int main(int argc, char ** argv) { } while (true) { - // debug - if (dump_kv_cache) { - llama_kv_cache_view_update(ctx, &kvc_view); - common_kv_cache_dump_view_seqs(kvc_view, 40); - } - // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/ // // Example for W = 5, N = 4, G = 2: @@ -438,17 +429,17 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation - llama_kv_self_seq_rm(ctx, -1, n_past, -1); + llama_memory_seq_rm(mem, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_kv_self_seq_keep(ctx, seq_id_best); - llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1); + llama_memory_seq_keep(mem, seq_id_best); + llama_memory_seq_cp (mem, seq_id_best, 0, -1, -1); + llama_memory_seq_rm (mem, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_self_seq_cp(ctx, 0, s, -1, -1); + llama_memory_seq_cp(mem, 0, s, -1, -1); } } } @@ -473,8 +464,6 @@ int main(int argc, char ** argv) { common_sampler_free(smpl); - llama_kv_cache_view_free(&kvc_view); - llama_batch_free(batch); llama_backend_free(); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 4ae93b2a5..2bfa26b55 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -24,8 +24,6 @@ int main(int argc, char ** argv){ // max. number of additional tokens to draft if match is found const int n_draft = params.speculative.n_max; - const bool dump_kv_cache = params.dump_kv_cache; - // init llama.cpp llama_backend_init(); llama_numa_init(params.numa); @@ -110,18 +108,9 @@ int main(int argc, char ** argv){ llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); - // debug - struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); - const auto t_dec_start = ggml_time_us(); while (true) { - // debug - if (dump_kv_cache) { - llama_kv_cache_view_update(ctx, &kvc_view); - common_kv_cache_dump_view_seqs(kvc_view, 40); - } - // print current draft sequence LOG_DBG("drafted %s\n", string_from(ctx, draft).c_str()); @@ -192,7 +181,7 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted - llama_kv_self_seq_rm(ctx, 0, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx), 0, n_past, -1); common_batch_clear(batch_tgt); common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/parallel/README.md b/examples/parallel/README.md index df0456733..2468a30d2 100644 --- a/examples/parallel/README.md +++ b/examples/parallel/README.md @@ -1,3 +1,14 @@ # llama.cpp/example/parallel Simplified simulation of serving incoming requests in parallel + +## Example + +Generate 128 client requests (`-ns 128`), simulating 8 concurrent clients (`-np 8`). The system prompt is shared (`-pps`), meaning that it is computed once at the start. The client requests consist of up to 10 junk questions (`--junk 10`) followed by the actual question. + +```bash +llama-parallel -m model.gguf -np 8 -ns 128 --top-k 1 -pps --junk 10 -c 16384 +``` + +> [!NOTE] +> It's recommended to use base models with this example. Instruction tuned models might not be able to properly follow the custom chat template specified here, so the results might not be as expected. diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 80698518e..d53e089a4 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -34,11 +34,61 @@ static std::string k_system = R"(Transcript of a never ending dialog, where the User interacts with an Assistant. The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. -User: Recommend a nice restaurant in the area. -Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays. -User: Who is Richard Feynman? -Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?". -User:)"; +User: +Recommend a nice restaurant in the area. +Assistant: +I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays. +User: +Who is Richard Feynman? +Assistant: +Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?". +)"; + +static std::vector k_questions = { + "What is the tallest mountain in the world?", + "Who was the first person to win two Nobel Prizes?", + "Which country invented paper?", + "What organ is primarily responsible for pumping blood throughout the body?", + "Which planet is known for its prominent ring system?", + "Who directed the movie 'Inception'?", + "What is the freezing point of water in Fahrenheit?", + "Which animal is known to have the longest lifespan?", + "What language has the most native speakers worldwide?", + "What is the capital city of Canada?", + "Who is credited with inventing the World Wide Web?", + "Which metal is liquid at room temperature?", + "What is the term for an animal that eats both plants and meat?", + "Who painted 'The Starry Night'?", + "What gas do humans exhale that plants use for photosynthesis?", + "What year did World War II end?", + "Which continent has the most countries?", + "Who wrote the novel 'Frankenstein'?", + "What does DNA stand for?", + "What is the main ingredient in traditional Japanese miso soup?" +}; + +static std::vector k_answers = { + "The tallest mountain in the world is Mount Everest.", + "Marie Curie was the first person to win two Nobel Prizes.", + "Paper was invented in China.", + "The heart is the organ responsible for pumping blood.", + "Saturn is known for its prominent ring system.", + "Christopher Nolan directed the movie 'Inception'.", + "The freezing point of water in Fahrenheit is 32°F.", + "The bowhead whale is known to have the longest lifespan among mammals.", + "Mandarin Chinese has the most native speakers in the world.", + "The capital city of Canada is Ottawa.", + "Tim Berners-Lee is credited with inventing the World Wide Web.", + "Mercury is the metal that is liquid at room temperature.", + "An animal that eats both plants and meat is called an omnivore.", + "'The Starry Night' was painted by Vincent van Gogh.", + "Humans exhale carbon dioxide, which plants use in photosynthesis.", + "World War II ended in 1945.", + "Africa is the continent with the most countries.", + "The novel 'Frankenstein' was written by Mary Shelley.", + "DNA stands for Deoxyribonucleic Acid.", + "The main ingredient in traditional Japanese miso soup is fermented soybean paste." +}; static std::vector k_prompts = { "What is the meaning of life?", @@ -49,7 +99,7 @@ static std::vector k_prompts = { "What is the best way to learn a new language?", "How to get a job at Google?", "If you could have any superpower, what would it be?", - "I want to learn how to play the piano.", + "I want to learn how to play the piano. What would be the best way to do it?", }; struct client { @@ -68,6 +118,7 @@ struct client { int64_t t_start_prompt; int64_t t_start_gen; + int32_t n_past = 0; int32_t n_prompt = 0; int32_t n_decoded = 0; int32_t i_batch = -1; @@ -107,6 +158,7 @@ int main(int argc, char ** argv) { common_params params; params.n_predict = 128; + params.n_junk = 1; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { return 1; @@ -126,7 +178,11 @@ int main(int argc, char ** argv) { // insert new requests as soon as the previous one is done const bool cont_batching = params.cont_batching; - const bool dump_kv_cache = params.dump_kv_cache; + // is the system prompt shared in the cache + const bool is_sp_shared = params.is_pp_shared; + + // extra text to insert in each client's prompt in order to make it larger + const int32_t n_junk = std::max(1, params.n_junk); // init llama.cpp llama_backend_init(); @@ -138,6 +194,8 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); + auto * mem = llama_get_memory(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); // load the prompts from an external file if there are any @@ -169,6 +227,7 @@ int main(int argc, char ** argv) { } std::vector tokens_system; + tokens_system = common_tokenize(ctx, k_system, true); const int32_t n_tokens_system = tokens_system.size(); @@ -182,15 +241,13 @@ int main(int argc, char ** argv) { int32_t n_total_gen = 0; int32_t n_cache_miss = 0; - struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients); - const auto t_main_start = ggml_time_us(); LOG_INF("%s: Simulating parallel requests from clients:\n", __func__); LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); LOG_INF("\n"); - { + if (is_sp_shared) { LOG_INF("%s: Evaluating the system prompt ...\n", __func__); for (int32_t i = 0; i < n_tokens_system; ++i) { @@ -204,7 +261,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_kv_self_seq_cp(ctx, 0, i, -1, -1); + llama_memory_seq_cp(mem, 0, i, -1, -1); } LOG_INF("\n"); @@ -213,11 +270,6 @@ int main(int argc, char ** argv) { LOG_INF("Processing requests ...\n\n"); while (true) { - if (dump_kv_cache) { - llama_kv_cache_view_update(ctx, &kvc_view); - common_kv_cache_dump_view_seqs(kvc_view, 40); - } - common_batch_clear(batch); // decode any currently ongoing sequences @@ -228,7 +280,7 @@ int main(int argc, char ** argv) { client.i_batch = batch.n_tokens; - common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); + common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true); client.n_decoded += 1; } @@ -236,9 +288,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_kv_self_seq_rm(ctx, i, -1, -1); + llama_memory_seq_rm(mem, i, -1, -1); // but keep the system prompt - llama_kv_self_seq_cp(ctx, 0, i, -1, -1); + llama_memory_seq_cp(mem, 0, i, -1, -1); } LOG_INF("%s: clearing the KV cache\n", __func__); @@ -254,9 +306,26 @@ int main(int argc, char ** argv) { client.t_start_gen = 0; client.input = k_prompts[rand() % k_prompts.size()]; - client.prompt = client.input + "\nAssistant:"; client.response = ""; + // construct the prompt: + // [system prompt] + [junk] + [user prompt] + client.n_past = 0; + client.prompt = ""; + if (is_sp_shared) { + client.n_past = n_tokens_system; + } else { + client.prompt += k_system; + } + + const int n_junk_cur = rand() % n_junk; + + for (int i = 0; i < n_junk_cur; ++i) { + const int r = rand() % k_questions.size(); + client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n"; + } + client.prompt += "User:\n" + client.input + "\nAssistant:\n"; + common_sampler_reset(client.smpl); // do not prepend BOS because we have a system prompt! @@ -264,7 +333,7 @@ int main(int argc, char ** argv) { tokens_prompt = common_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); + common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false); } // extract the logits only for the last token @@ -276,7 +345,7 @@ int main(int argc, char ** argv) { client.n_decoded = 0; client.i_batch = batch.n_tokens - 1; - LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); + LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur); g_seq_id += 1; @@ -295,7 +364,9 @@ int main(int argc, char ** argv) { // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + int32_t i_next = 0; + + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { // n_batch /= 2; @@ -303,7 +374,7 @@ int main(int argc, char ** argv) { // continue; //} - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); llama_batch batch_view = { n_tokens, @@ -323,19 +394,24 @@ int main(int argc, char ** argv) { return 1; } - LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); n_cache_miss += 1; // retry with half the batch size to try to find a free slot in the KV cache n_batch /= 2; - i -= n_batch; continue; } LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens); + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = params.n_batch; + for (auto & client : clients) { if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { continue; @@ -363,10 +439,9 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); if (client.n_decoded > 2 && - (llama_vocab_is_eog(vocab, id) || - (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || - client.response.find("User:") != std::string::npos || - client.response.find('\n') != std::string::npos)) { + (llama_vocab_is_eog(vocab, id) || + (params.n_predict > 0 && client.n_decoded >= params.n_predict) || + client.response.find("User:") != std::string::npos)) { // basic reverse prompt const size_t pos = client.response.find("User:"); if (pos != std::string::npos) { @@ -374,8 +449,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1); - llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_memory_seq_rm(mem, client.id + 1, -1, -1); + llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 347ea4a69..8a4faa383 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,6 +126,8 @@ int main(int argc, char ** argv) { int n_past = 0; + auto * mem = llama_get_memory(ctx); + // fill the KV cache for (int i = 0; i < n_ctx; i += n_batch) { if (i > 0 && n_grp > 1) { @@ -133,11 +135,10 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_self_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_self_update (ctx); + llama_memory_seq_add(mem, 0, n_past - n_batch, n_past, ib*bd); + llama_memory_seq_div(mem, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; + n_past = llama_memory_seq_pos_max(mem, 0) + 1; } common_batch_clear(batch); @@ -167,12 +168,10 @@ int main(int argc, char ** argv) { LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_self_defrag (ctx); - llama_kv_self_update (ctx); + llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard); + llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard); - n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; + n_past = llama_memory_seq_pos_max(mem, 0) + 1; common_batch_clear(batch); @@ -198,12 +197,10 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_self_defrag (ctx); - llama_kv_self_update (ctx); + llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard); + llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard); - n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; + n_past = llama_memory_seq_pos_max(mem, 0) + 1; } } diff --git a/examples/pydantic_models_to_grammar_examples.py b/examples/pydantic_models_to_grammar_examples.py index f94b82ca4..6dadb7f3f 100755 --- a/examples/pydantic_models_to_grammar_examples.py +++ b/examples/pydantic_models_to_grammar_examples.py @@ -23,7 +23,7 @@ def create_completion(host, prompt, gbnf_grammar): """Calls the /completion API on llama-server. See - https://github.com/ggml-org/llama.cpp/tree/HEAD/examples/server#api-endpoints + https://github.com/ggml-org/llama.cpp/tree/HEAD/tools/server#api-endpoints """ print(f" Request:\n Grammar:\n{textwrap.indent(gbnf_grammar, ' ')}\n Prompt:\n{textwrap.indent(prompt.rstrip(), ' ')}") headers = {"Content-Type": "application/json"} diff --git a/examples/quantize-stats/CMakeLists.txt b/examples/quantize-stats/CMakeLists.txt deleted file mode 100644 index 9a3a0d3cd..000000000 --- a/examples/quantize-stats/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -set(TARGET llama-quantize-stats) -add_executable(${TARGET} quantize-stats.cpp) -install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT}) -target_include_directories(${TARGET} PRIVATE ../../common) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 0efe20d4b..042e12c2b 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { +static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_self_clear(ctx); + llama_memory_clear(llama_get_memory(ctx), false); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); if (llama_decode(ctx, batch) < 0) { - LOG_ERR("%s : failed to decode\n", __func__); + LOG_ERR("%s : failed to process\n", __func__); } for (int i = 0; i < batch.n_tokens; i++) { @@ -233,7 +233,7 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch) { float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_process(ctx, batch, out, s, n_embd); common_batch_clear(batch); p += s; s = 0; @@ -246,7 +246,7 @@ int main(int argc, char ** argv) { // final batch float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_process(ctx, batch, out, s, n_embd); // save embeddings to chunks for (int i = 0; i < n_chunks; i++) { @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { batch_add_seq(query_batch, query_tokens, 0); std::vector query_emb(n_embd, 0); - batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); + batch_process(ctx, query_batch, query_emb.data(), 1, n_embd); common_batch_clear(query_batch); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 760ebbbf0..db79588f1 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -196,7 +196,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_kv_self_clear(ctx3); + llama_memory_clear(llama_get_memory(ctx3), true); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/public/index.html.gz b/examples/server/public/index.html.gz deleted file mode 100644 index 674e22757..000000000 Binary files a/examples/server/public/index.html.gz and /dev/null differ diff --git a/examples/server/webui/src/components/ChatScreen.tsx b/examples/server/webui/src/components/ChatScreen.tsx deleted file mode 100644 index 29ab5ea64..000000000 --- a/examples/server/webui/src/components/ChatScreen.tsx +++ /dev/null @@ -1,296 +0,0 @@ -import { useEffect, useMemo, useState } from 'react'; -import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context'; -import ChatMessage from './ChatMessage'; -import { CanvasType, Message, PendingMessage } from '../utils/types'; -import { classNames, cleanCurrentUrl, throttle } from '../utils/misc'; -import CanvasPyInterpreter from './CanvasPyInterpreter'; -import StorageUtils from '../utils/storage'; -import { useVSCodeContext } from '../utils/llama-vscode'; -import { useChatTextarea, ChatTextareaApi } from './useChatTextarea.ts'; - -/** - * A message display is a message node with additional information for rendering. - * For example, siblings of the message node are stored as their last node (aka leaf node). - */ -export interface MessageDisplay { - msg: Message | PendingMessage; - siblingLeafNodeIds: Message['id'][]; - siblingCurrIdx: number; - isPending?: boolean; -} - -/** - * If the current URL contains "?m=...", prefill the message input with the value. - * If the current URL contains "?q=...", prefill and SEND the message. - */ -const prefilledMsg = { - content() { - const url = new URL(window.location.href); - return url.searchParams.get('m') ?? url.searchParams.get('q') ?? ''; - }, - shouldSend() { - const url = new URL(window.location.href); - return url.searchParams.has('q'); - }, - clear() { - cleanCurrentUrl(['m', 'q']); - }, -}; - -function getListMessageDisplay( - msgs: Readonly, - leafNodeId: Message['id'] -): MessageDisplay[] { - const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true); - const res: MessageDisplay[] = []; - const nodeMap = new Map(); - for (const msg of msgs) { - nodeMap.set(msg.id, msg); - } - // find leaf node from a message node - const findLeafNode = (msgId: Message['id']): Message['id'] => { - let currNode: Message | undefined = nodeMap.get(msgId); - while (currNode) { - if (currNode.children.length === 0) break; - currNode = nodeMap.get(currNode.children.at(-1) ?? -1); - } - return currNode?.id ?? -1; - }; - // traverse the current nodes - for (const msg of currNodes) { - const parentNode = nodeMap.get(msg.parent ?? -1); - if (!parentNode) continue; - const siblings = parentNode.children; - if (msg.type !== 'root') { - res.push({ - msg, - siblingLeafNodeIds: siblings.map(findLeafNode), - siblingCurrIdx: siblings.indexOf(msg.id), - }); - } - } - return res; -} - -const scrollToBottom = throttle( - (requiresNearBottom: boolean, delay: number = 80) => { - const mainScrollElem = document.getElementById('main-scroll'); - if (!mainScrollElem) return; - const spaceToBottom = - mainScrollElem.scrollHeight - - mainScrollElem.scrollTop - - mainScrollElem.clientHeight; - if (!requiresNearBottom || spaceToBottom < 50) { - setTimeout( - () => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }), - delay - ); - } - }, - 80 -); - -export default function ChatScreen() { - const { - viewingChat, - sendMessage, - isGenerating, - stopGenerating, - pendingMessages, - canvasData, - replaceMessageAndGenerate, - } = useAppContext(); - - const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content()); - - const { extraContext, clearExtraContext } = useVSCodeContext(textarea); - // TODO: improve this when we have "upload file" feature - const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined; - - // keep track of leaf node for rendering - const [currNodeId, setCurrNodeId] = useState(-1); - const messages: MessageDisplay[] = useMemo(() => { - if (!viewingChat) return []; - else return getListMessageDisplay(viewingChat.messages, currNodeId); - }, [currNodeId, viewingChat]); - - const currConvId = viewingChat?.conv.id ?? null; - const pendingMsg: PendingMessage | undefined = - pendingMessages[currConvId ?? '']; - - useEffect(() => { - // reset to latest node when conversation changes - setCurrNodeId(-1); - // scroll to bottom when conversation changes - scrollToBottom(false, 1); - }, [currConvId]); - - const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => { - if (currLeafNodeId) { - setCurrNodeId(currLeafNodeId); - } - scrollToBottom(true); - }; - - const sendNewMessage = async () => { - const lastInpMsg = textarea.value(); - if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? '')) - return; - textarea.setValue(''); - scrollToBottom(false); - setCurrNodeId(-1); - // get the last message node - const lastMsgNodeId = messages.at(-1)?.msg.id ?? null; - if ( - !(await sendMessage( - currConvId, - lastMsgNodeId, - lastInpMsg, - currExtra, - onChunk - )) - ) { - // restore the input message if failed - textarea.setValue(lastInpMsg); - } - // OK - clearExtraContext(); - }; - - const handleEditMessage = async (msg: Message, content: string) => { - if (!viewingChat) return; - setCurrNodeId(msg.id); - scrollToBottom(false); - await replaceMessageAndGenerate( - viewingChat.conv.id, - msg.parent, - content, - msg.extra, - onChunk - ); - setCurrNodeId(-1); - scrollToBottom(false); - }; - - const handleRegenerateMessage = async (msg: Message) => { - if (!viewingChat) return; - setCurrNodeId(msg.parent); - scrollToBottom(false); - await replaceMessageAndGenerate( - viewingChat.conv.id, - msg.parent, - null, - msg.extra, - onChunk - ); - setCurrNodeId(-1); - scrollToBottom(false); - }; - - const hasCanvas = !!canvasData; - - useEffect(() => { - if (prefilledMsg.shouldSend()) { - // send the prefilled message if needed - sendNewMessage(); - } else { - // otherwise, focus on the input - textarea.focus(); - } - prefilledMsg.clear(); - // no need to keep track of sendNewMessage - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [textarea.ref]); - - // due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg) - const pendingMsgDisplay: MessageDisplay[] = - pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id - ? [ - { - msg: pendingMsg, - siblingLeafNodeIds: [], - siblingCurrIdx: 0, - isPending: true, - }, - ] - : []; - - return ( -
-
- {/* chat messages */} -
-
- {/* placeholder to shift the message to the bottom */} - {viewingChat ? '' : 'Send a message to start'} -
- {[...messages, ...pendingMsgDisplay].map((msg) => ( - - ))} -
- - {/* chat input */} -
- - - {isGenerating(currConvId ?? '') ? ( - - ) : ( - - )} -
-
-
- {canvasData?.type === CanvasType.PY_INTERPRETER && ( - - )} -
-
- ); -} diff --git a/examples/server/webui/src/components/Header.tsx b/examples/server/webui/src/components/Header.tsx deleted file mode 100644 index 4c6b291e6..000000000 --- a/examples/server/webui/src/components/Header.tsx +++ /dev/null @@ -1,178 +0,0 @@ -import { useEffect, useState } from 'react'; -import StorageUtils from '../utils/storage'; -import { useAppContext } from '../utils/app.context'; -import { classNames } from '../utils/misc'; -import daisyuiThemes from 'daisyui/theme/object'; -import { THEMES } from '../Config'; -import { useNavigate } from 'react-router'; - -export default function Header() { - const navigate = useNavigate(); - const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme()); - const { setShowSettings } = useAppContext(); - - const setTheme = (theme: string) => { - StorageUtils.setTheme(theme); - setSelectedTheme(theme); - }; - - useEffect(() => { - document.body.setAttribute('data-theme', selectedTheme); - document.body.setAttribute( - 'data-color-scheme', - daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto' - ); - }, [selectedTheme]); - - const { isGenerating, viewingChat } = useAppContext(); - const isCurrConvGenerating = isGenerating(viewingChat?.conv.id ?? ''); - - const removeConversation = () => { - if (isCurrConvGenerating || !viewingChat) return; - const convId = viewingChat?.conv.id; - if (window.confirm('Are you sure to delete this conversation?')) { - StorageUtils.remove(convId); - navigate('/'); - } - }; - - const downloadConversation = () => { - if (isCurrConvGenerating || !viewingChat) return; - const convId = viewingChat?.conv.id; - const conversationJson = JSON.stringify(viewingChat, null, 2); - const blob = new Blob([conversationJson], { type: 'application/json' }); - const url = URL.createObjectURL(blob); - const a = document.createElement('a'); - a.href = url; - a.download = `conversation_${convId}.json`; - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); - URL.revokeObjectURL(url); - }; - - return ( -
- {/* open sidebar button */} - - -
llama.cpp
- - {/* action buttons (top right) */} -
- {viewingChat && ( -
- {/* "..." button */} - - {/* dropdown menu */} - -
- )} - -
- -
- - {/* theme controller is copied from https://daisyui.com/components/theme-controller/ */} -
-
-
- - - -
-
    -
  • - -
  • - {THEMES.map((theme) => ( -
  • - e.target.checked && setTheme(theme)} - /> -
  • - ))} -
-
-
-
-
- ); -} diff --git a/examples/server/webui/src/components/Sidebar.tsx b/examples/server/webui/src/components/Sidebar.tsx deleted file mode 100644 index 34727c623..000000000 --- a/examples/server/webui/src/components/Sidebar.tsx +++ /dev/null @@ -1,96 +0,0 @@ -import { useEffect, useState } from 'react'; -import { classNames } from '../utils/misc'; -import { Conversation } from '../utils/types'; -import StorageUtils from '../utils/storage'; -import { useNavigate, useParams } from 'react-router'; - -export default function Sidebar() { - const params = useParams(); - const navigate = useNavigate(); - - const [conversations, setConversations] = useState([]); - const [currConv, setCurrConv] = useState(null); - - useEffect(() => { - StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv); - }, [params.convId]); - - useEffect(() => { - const handleConversationChange = async () => { - setConversations(await StorageUtils.getAllConversations()); - }; - StorageUtils.onConversationChanged(handleConversationChange); - handleConversationChange(); - return () => { - StorageUtils.offConversationChanged(handleConversationChange); - }; - }, []); - - return ( - <> - - -
- -
-
-

Conversations

- - {/* close sidebar button */} - -
- - {/* list of conversations */} -
navigate('/')} - > - + New conversation -
- {conversations.map((conv) => ( -
navigate(`/chat/${conv.id}`)} - dir="auto" - > - {conv.name} -
- ))} -
- Conversations are saved to browser's IndexedDB -
-
-
- - ); -} diff --git a/examples/server/webui/src/utils/common.tsx b/examples/server/webui/src/utils/common.tsx deleted file mode 100644 index 09b08b5c9..000000000 --- a/examples/server/webui/src/utils/common.tsx +++ /dev/null @@ -1,38 +0,0 @@ -export const XCloseButton: React.ElementType< - React.ClassAttributes & - React.HTMLAttributes -> = ({ className, ...props }) => ( - -); - -export const OpenInNewTab = ({ - href, - children, -}: { - href: string; - children: string; -}) => ( - - {children} - -); diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 84f415973..2aee0a919 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -98,7 +98,7 @@ int main(int argc, char ** argv) { auto generate = [&](const std::string & prompt) { std::string response; - const bool is_first = llama_kv_self_used_cells(ctx) == 0; + const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0; // tokenize the prompt const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); @@ -113,7 +113,7 @@ int main(int argc, char ** argv) { while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); - int n_ctx_used = llama_kv_self_used_cells(ctx); + int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0); if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 10e79a0a6..633b87e58 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -84,13 +84,13 @@ int main(int argc, char ** argv) { model_params.n_gpu_layers = ngl; llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); - const llama_vocab * vocab = llama_model_get_vocab(model); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; } + const llama_vocab * vocab = llama_model_get_vocab(model); // tokenize the prompt // find the number of tokens in the prompt diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 0783ed4a4..99196c9d0 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -217,7 +217,7 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 561c30883..0adffdb00 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -142,6 +142,8 @@ int main(int argc, char ** argv) { } } + auto * mem_tgt = llama_get_memory(ctx_tgt); + auto * mem_dft = llama_get_memory(ctx_dft); // Tokenize the prompt std::vector inp; @@ -420,14 +422,14 @@ int main(int argc, char ** argv) { { LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_kv_self_seq_keep(ctx_dft, s_keep); - llama_kv_self_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_self_seq_keep(ctx_dft, 0); + llama_memory_seq_keep(mem_dft, s_keep); + llama_memory_seq_cp (mem_dft, s_keep, 0, -1, -1); + llama_memory_seq_keep(mem_dft, 0); - llama_kv_self_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_self_seq_keep(ctx_tgt, s_keep); - llama_kv_self_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_self_seq_keep(ctx_tgt, 0); + llama_memory_seq_rm (mem_tgt, s_keep, n_past_tgt, -1); + llama_memory_seq_keep(mem_tgt, s_keep); + llama_memory_seq_cp (mem_tgt, s_keep, 0, -1, -1); + llama_memory_seq_keep(mem_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -444,7 +446,7 @@ int main(int argc, char ** argv) { common_batch_clear(batch_dft); common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); - llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); + llama_memory_seq_rm(mem_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); @@ -503,8 +505,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_kv_self_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_memory_seq_rm(mem_dft, n_seq_cur, -1, -1); + llama_memory_seq_cp(mem_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -585,9 +587,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_self_seq_keep(ctx_tgt, 0); + llama_memory_seq_keep(mem_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_memory_seq_cp(mem_tgt, 0, s, -1, -1); } // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/examples/sycl/build.sh b/examples/sycl/build.sh index 8fe0a6790..e72b2e261 100755 --- a/examples/sycl/build.sh +++ b/examples/sycl/build.sh @@ -8,10 +8,10 @@ cd build source /opt/intel/oneapi/setvars.sh #for FP16 -#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON # faster for long-prompt inference +#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DLLAMA_CURL=OFF # faster for long-prompt inference #for FP32 -cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx +cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_CURL=OFF #build example/main #cmake --build . --config Release --target main diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 8ce71e370..40ce8f5b2 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -12,16 +12,16 @@ source /opt/intel/oneapi/setvars.sh INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:" MODEL_FILE=models/llama-2-7b.Q4_0.gguf -NGL=33 -CONEXT=4096 +NGL=99 +CONTEXT=4096 if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" #use signle GPU only - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} fi diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh new file mode 100755 index 000000000..933d1b98b --- /dev/null +++ b/examples/sycl/run-llama3.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# MIT license +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: MIT + +# If you want more control, DPC++ Allows selecting a specific device through the +# following environment variable +#export ONEAPI_DEVICE_SELECTOR="level_zero:0" +source /opt/intel/oneapi/setvars.sh + +#export GGML_SYCL_DEBUG=1 + +#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer. + +INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:" +MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf +NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using. +CONTEXT=4096 + +if [ $# -gt 0 ]; then + GGML_SYCL_DEVICE=$1 + echo "Using $GGML_SYCL_DEVICE as the main GPU" + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none +else + #use multiple GPUs with same max compute units + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} +fi diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat index c2918d6dc..d7564f416 100644 --- a/examples/sycl/win-run-llama2.bat +++ b/examples/sycl/win-run-llama2.bat @@ -6,4 +6,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force -.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0 +.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0 diff --git a/examples/sycl/win-run-llama3.bat b/examples/sycl/win-run-llama3.bat new file mode 100644 index 000000000..4b61aebee --- /dev/null +++ b/examples/sycl/win-run-llama3.bat @@ -0,0 +1,9 @@ +:: MIT license +:: Copyright (C) 2024 Intel Corporation +:: SPDX-License-Identifier: MIT + +set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" +@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force + + +.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99 diff --git a/examples/training/CMakeLists.txt b/examples/training/CMakeLists.txt new file mode 100644 index 000000000..64afe6ddc --- /dev/null +++ b/examples/training/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-finetune) +add_executable(${TARGET} finetune.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/examples/training/README.md b/examples/training/README.md new file mode 100644 index 000000000..df4252792 --- /dev/null +++ b/examples/training/README.md @@ -0,0 +1,17 @@ +# llama.cpp/examples/training + +This directory contains examples related to language model training using llama.cpp/GGML. +So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP. +Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory. +**For CPU training, compile llama.cpp without any additional backends such as CUDA.** +**For CUDA training, use the maximum number of GPU layers.** + +Proof of concept: + +``` sh +export model_name=llama_3.2-1b && export quantization=f32 +./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512 +./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf +``` + +The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs. diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp new file mode 100644 index 000000000..23bede49b --- /dev/null +++ b/examples/training/finetune.cpp @@ -0,0 +1,96 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +int main(int argc, char ** argv) { + common_params params; + + params.escape = false; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) { + return 1; + } + + if (params.use_mmap) { + LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__); + params.use_mmap = false; + } + if (params.cache_type_k != GGML_TYPE_F32) { + LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_k = GGML_TYPE_F32; + } + if (params.cache_type_v != GGML_TYPE_F32) { + LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__); + params.cache_type_v = GGML_TYPE_F32; + } + + common_init(); + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model and apply lora adapter, if any + common_init_result llama_init = common_init_from_params(params); + llama_model_ptr & model = llama_init.model; + llama_context_ptr & ctx = llama_init.context; + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n", __func__); + return 1; + } + + // print system information + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + } + + constexpr float val_split = 0.05f; + + std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); + ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2); + + struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr); + optimizer_params.adamw.alpha = 1e-7f; // learning rate + + struct llama_opt_params lopt_params { + /*n_ctx_train =*/ 0, + /*param_filter =*/ llama_opt_param_filter_all, + /*param_filter_ud =*/ nullptr, + /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params, + /*get_opt_pars_ud =*/ &optimizer_params, + }; + llama_opt_init(ctx.get(), model.get(), lopt_params); + + const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split); + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_eval = ggml_opt_result_init(); + + for (int epoch = 0; epoch < 2; ++epoch) { + llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, + ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); + fprintf(stderr, "\n"); + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_eval); + } + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_eval); + + llama_model_save_to_file(model.get(), "finetuned-model.gguf"); + + llama_backend_free(); + + return 0; +} diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index d33f843b4..4e7399f9e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -105,8 +105,9 @@ message(DEBUG "GGML_NATIVE_DEFAULT : ${GGML_NATIVE_DEFAULT}") message(DEBUG "INS_ENB : ${INS_ENB}") option(GGML_CPU_HBM "ggml: use memkind for CPU HBM" OFF) -option(GGML_CPU_AARCH64 "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) +option(GGML_CPU_REPACK "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON) option(GGML_CPU_KLEIDIAI "ggml: use KleidiAI optimized kernels if applicable" OFF) +option(GGML_SSE42 "ggml: enable SSE 4.2" ${INS_ENB}) option(GGML_AVX "ggml: enable AVX" ${INS_ENB}) option(GGML_AVX_VNNI "ggml: enable AVX-VNNI" OFF) option(GGML_AVX2 "ggml: enable AVX2" ${INS_ENB}) @@ -128,6 +129,7 @@ option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LSX "ggml: enable lsx" ON) option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) +option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ON) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) @@ -135,7 +137,7 @@ set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") set(GGML_CPU_POWERPC_CPUTYPE "" CACHE STRING "ggml: CPU type for PowerPC") -if (WIN32) +if (MINGW) set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows version") endif() @@ -170,13 +172,12 @@ option(GGML_HIP "ggml: use HIP" option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) -option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF) +option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF) option(GGML_VULKAN_MEMORY_DEBUG "ggml: enable Vulkan memory debug output" OFF) option(GGML_VULKAN_SHADER_DEBUG_INFO "ggml: enable Vulkan shader debug info" OFF) -option(GGML_VULKAN_PERF "ggml: enable Vulkan perf output" OFF) option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation" OFF) option(GGML_VULKAN_RUN_TESTS "ggml: run Vulkan tests" OFF) option(GGML_KOMPUTE "ggml: use Kompute" OFF) @@ -193,6 +194,7 @@ option(GGML_RPC "ggml: use RPC" option(GGML_SYCL "ggml: use SYCL" OFF) option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING @@ -360,3 +362,73 @@ write_basic_package_version_file( install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml-config.cmake ${CMAKE_CURRENT_BINARY_DIR}/ggml-version.cmake DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ggml) + +if (MSVC) + set(MSVC_WARNING_FLAGS + /wd4005 # Macro redefinition + /wd4244 # Conversion from one type to another type, possible loss of data + /wd4267 # Conversion from 'size_t' to a smaller type, possible loss of data + /wd4305 # Conversion from 'type1' to 'type2', possible loss of data + /wd4566 # Conversion from 'char' to 'wchar_t', possible loss of data + /wd4996 # Disable POSIX deprecation warnings + /wd4702 # Unreachable code warnings + ) + function(disable_msvc_warnings target_name) + if(TARGET ${target_name}) + target_compile_options(${target_name} PRIVATE ${MSVC_WARNING_FLAGS}) + endif() + endfunction() + + disable_msvc_warnings(ggml-base) + disable_msvc_warnings(ggml) + disable_msvc_warnings(ggml-cpu) + disable_msvc_warnings(ggml-cpu-x64) + disable_msvc_warnings(ggml-cpu-sse42) + disable_msvc_warnings(ggml-cpu-sandybridge) + disable_msvc_warnings(ggml-cpu-haswell) + disable_msvc_warnings(ggml-cpu-skylakex) + disable_msvc_warnings(ggml-cpu-icelake) + disable_msvc_warnings(ggml-cpu-alderlake) + + if (GGML_BUILD_EXAMPLES) + disable_msvc_warnings(common-ggml) + disable_msvc_warnings(common) + + disable_msvc_warnings(mnist-common) + disable_msvc_warnings(mnist-eval) + disable_msvc_warnings(mnist-train) + + disable_msvc_warnings(gpt-2-ctx) + disable_msvc_warnings(gpt-2-alloc) + disable_msvc_warnings(gpt-2-backend) + disable_msvc_warnings(gpt-2-sched) + disable_msvc_warnings(gpt-2-quantize) + disable_msvc_warnings(gpt-2-batched) + + disable_msvc_warnings(gpt-j) + disable_msvc_warnings(gpt-j-quantize) + + disable_msvc_warnings(magika) + disable_msvc_warnings(yolov3-tiny) + disable_msvc_warnings(sam) + + disable_msvc_warnings(simple-ctx) + disable_msvc_warnings(simple-backend) + endif() + + if (GGML_BUILD_TESTS) + disable_msvc_warnings(test-mul-mat) + disable_msvc_warnings(test-arange) + disable_msvc_warnings(test-backend-ops) + disable_msvc_warnings(test-cont) + disable_msvc_warnings(test-conv-transpose) + disable_msvc_warnings(test-conv-transpose-1d) + disable_msvc_warnings(test-conv1d) + disable_msvc_warnings(test-conv2d) + disable_msvc_warnings(test-conv2d-dw) + disable_msvc_warnings(test-customop) + disable_msvc_warnings(test-dup) + disable_msvc_warnings(test-opt) + disable_msvc_warnings(test-pool) + endif () +endif() diff --git a/ggml/cmake/common.cmake b/ggml/cmake/common.cmake index 1976d0ae9..cb6638833 100644 --- a/ggml/cmake/common.cmake +++ b/ggml/cmake/common.cmake @@ -24,3 +24,27 @@ function(ggml_get_flags CCID CCVER) set(GF_C_FLAGS ${C_FLAGS} PARENT_SCOPE) set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE) endfunction() + +function(ggml_get_system_arch) + if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR + CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) + set(GGML_SYSTEM_ARCH "ARM" PARENT_SCOPE) + elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR + CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR + (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND + CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$")) + set(GGML_SYSTEM_ARCH "x86" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc|power") + set(GGML_SYSTEM_ARCH "PowerPC" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + set(GGML_SYSTEM_ARCH "loongarch64" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") + set(GGML_SYSTEM_ARCH "riscv64" PARENT_SCOPE) + elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + set(GGML_SYSTEM_ARCH "s390x" PARENT_SCOPE) + else() + set(GGML_SYSTEM_ARCH "UNKNOWN" PARENT_SCOPE) + endif() +endfunction() diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 64671495b..778927f68 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -38,7 +38,7 @@ extern "C" { GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size); GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft); GGML_API size_t ggml_backend_buft_get_max_size (ggml_backend_buffer_type_t buft); - GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft); GGML_API ggml_backend_dev_t ggml_backend_buft_get_device (ggml_backend_buffer_type_t buft); @@ -59,7 +59,7 @@ extern "C" { GGML_API enum ggml_status ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer); GGML_API size_t ggml_backend_buffer_get_max_size (ggml_backend_buffer_t buffer); - GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor); GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value); GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer); GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); @@ -248,7 +248,7 @@ extern "C" { // preferrably to run on the same backend as the buffer ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false); + sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); // initialize buffers from a max size graph (optional) reserve_graph = build_graph(sched, max_batch_size); @@ -289,7 +289,7 @@ extern "C" { typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); // Initialize a backend scheduler, backends with low index are given priority over backends with high index - GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); + GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); // Initialize backend buffers from a measure graph diff --git a/ggml/include/ggml-cpp.h b/ggml/include/ggml-cpp.h index a12342c25..48aa79682 100644 --- a/ggml/include/ggml-cpp.h +++ b/ggml/include/ggml-cpp.h @@ -24,7 +24,7 @@ typedef std::unique_ptr gguf_context_ptr; struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } }; -typedef std::unique_ptr ggml_gallocr_ptr; +typedef std::unique_ptr ggml_gallocr_ptr; // ggml-backend diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index f5e11f1e1..de77a875e 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -133,6 +133,11 @@ extern "C" { GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); + GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); + GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); + GGML_BACKEND_API void ggml_cpu_bf16_to_fp32(const ggml_bf16_t *, float *, int64_t); + #ifdef __cplusplus } #endif diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index eb5eab9de..74ec080a0 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -37,13 +37,16 @@ extern "C" { // ====== Dataset ====== GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( - int64_t ne_datapoint, // number of elements per datapoint - int64_t ne_label, // number of elements per label - int64_t ndata, // total number of datapoints/labels - int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + enum ggml_type type_data, // the type for the internal data tensor + enum ggml_type type_label, // the type for the internal labels tensor + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); // get underlying tensors that store the data + GGML_API int64_t ggml_opt_dataset_ndata (ggml_opt_dataset_t dataset); GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] @@ -56,13 +59,19 @@ extern "C" { struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] int64_t ibatch); + GGML_API void ggml_opt_dataset_get_batch_host( + ggml_opt_dataset_t dataset, + void * data_batch, + size_t nb_data_batch, + void * labels_batch, + int64_t ibatch); // ====== Model / Context ====== enum ggml_opt_build_type { - GGML_OPT_BUILD_TYPE_FORWARD, - GGML_OPT_BUILD_TYPE_GRAD, - GGML_OPT_BUILD_TYPE_OPT, + GGML_OPT_BUILD_TYPE_FORWARD = 10, + GGML_OPT_BUILD_TYPE_GRAD = 20, + GGML_OPT_BUILD_TYPE_OPT = 30, }; // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss @@ -81,20 +90,22 @@ extern "C" { // userdata can be used to pass arbitrary data typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); - // returns the default optimizer params (constant) + // returns the default optimizer params (constant, hard-coded values) // userdata is not used GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + // casts userdata to ggml_opt_optimizer_params and returns it + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata); + // parameters for initializing a new optimization context struct ggml_opt_params { ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs - struct ggml_context * ctx_compute; // created in user code, holds non-static tensors - - // the forward graph is defined by inputs and outputs - // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts - struct ggml_tensor * inputs; - struct ggml_tensor * outputs; + // by default the forward graph needs to be reconstructed for each eval + // if ctx_compute, inputs, and outputs are set the graphs are instead allocated statically + struct ggml_context * ctx_compute; + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; enum ggml_opt_loss_type loss_type; enum ggml_opt_build_type build_type; @@ -107,12 +118,9 @@ extern "C" { // get parameters for an optimization context with defaults set where possible // parameters for which no sensible defaults exist are supplied as arguments to this function - GGML_API ggml_opt_params ggml_opt_default_params( - ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, - enum ggml_opt_loss_type loss_type); + GGML_API struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + enum ggml_opt_loss_type loss_type); GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); @@ -120,7 +128,10 @@ extern "C" { // set gradients to zero, initilize loss, and optionally reset the optimizer GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically + // get underlying tensors that store data + // if not using static graphs these pointers become invalid with the next call to ggml_opt_alloc GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against @@ -128,11 +139,12 @@ extern "C" { GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + // get the gradient accumulator for a node from the forward graph GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); // ====== Optimization Result ====== - GGML_API ggml_opt_result_t ggml_opt_result_init(); + GGML_API ggml_opt_result_t ggml_opt_result_init(void); GGML_API void ggml_opt_result_free(ggml_opt_result_t result); GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); @@ -144,11 +156,20 @@ extern "C" { // ====== Computation ====== - // do forward pass, increment result if not NULL - GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // if not using static graphs, this function must be called prior to ggml_opt_alloc + GGML_API void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs); - // do forward pass, increment result if not NULL, do backward pass - GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + // allocate the next graph for evaluation, either forward or forward + backward + // must be called exactly once prior to calling ggml_opt_eval + GGML_API void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward); + + // do forward pass, increment result if not NULL, do backward pass if allocated + GGML_API void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); // ############################################################################ // ## The high-level functions start here. They do not depend on any private ## @@ -200,9 +221,9 @@ extern "C" { // fit model defined by inputs and outputs to dataset GGML_API void ggml_opt_fit( ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs - ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs - ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] - ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + struct ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + struct ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used ggml_opt_dataset_t dataset, // dataset with data and optionally also labels enum ggml_opt_loss_type loss_type, // loss to minimize ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 4e0d210f8..1e6741127 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -7,6 +7,9 @@ extern "C" { #endif +#define RPC_PROTO_MAJOR_VERSION 2 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 8fcc16df9..1a57f1cd7 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -393,8 +393,8 @@ extern "C" { // precision enum ggml_prec { - GGML_PREC_DEFAULT, - GGML_PREC_F32, + GGML_PREC_DEFAULT = 0, // stored as ggml_tensor.op_params, 0 by default + GGML_PREC_F32 = 10, }; // model file types @@ -481,6 +481,7 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, + GGML_OP_CONV_2D_DW, GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, @@ -535,6 +536,7 @@ extern "C" { GGML_UNARY_OP_HARDSWISH, GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, + GGML_UNARY_OP_GELU_ERF, GGML_UNARY_OP_COUNT, }; @@ -672,11 +674,18 @@ extern "C" { GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor); GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars + // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation) GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor); GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous() GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1 GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2 + // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok) + GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor); + + // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN + GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor); + GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1); @@ -760,7 +769,7 @@ extern "C" { // Tensor flags GGML_API void ggml_set_input(struct ggml_tensor * tensor); GGML_API void ggml_set_output(struct ggml_tensor * tensor); - GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_set_param(struct ggml_tensor * tensor); GGML_API void ggml_set_loss(struct ggml_tensor * tensor); // @@ -926,11 +935,20 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // repeat a to the specified shape + GGML_API struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + // sums repetitions in a into shape of b GGML_API struct ggml_tensor * ggml_repeat_back( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride // concat a and b along dim // used in stable-diffusion @@ -1016,6 +1034,16 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // GELU using erf (error function) when possible + // some backends may fallback to approximation based on Abramowitz and Stegun formula + GGML_API struct ggml_tensor * ggml_gelu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_erf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + GGML_API struct ggml_tensor * ggml_gelu_quick( struct ggml_context * ctx, struct ggml_tensor * a); @@ -1660,7 +1688,7 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); - // depthwise + // depthwise (via im2col and mul_mat) GGML_API struct ggml_tensor * ggml_conv_2d_dw( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel @@ -1672,6 +1700,22 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 + // Depthwise 2D convolution + // may be faster than ggml_conv_2d_dw, but not available in all backends + // a: KW KH 1 C convolution kernel + // b: W H C N input data + // res: W_out H_out C N + GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int pad1, + int dilation0, + int dilation1); + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( struct ggml_context * ctx, struct ggml_tensor * a, @@ -2025,15 +2069,14 @@ extern "C" { GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( - struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) - struct ggml_context * ctx_compute, // context for gradient computation - struct ggml_cgraph * cgraph, - bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static + struct ggml_context * ctx, // context for gradient computation + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs); // graph allocation in a context GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); - GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads); GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); @@ -2052,9 +2095,6 @@ extern "C" { GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); - GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); - GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); - // print info and performance information for the graph GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); @@ -2138,6 +2178,7 @@ extern "C" { // scheduling priorities enum ggml_sched_priority { + GGML_SCHED_PRIO_LOW = -1, GGML_SCHED_PRIO_NORMAL, GGML_SCHED_PRIO_MEDIUM, GGML_SCHED_PRIO_HIGH, diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index f00700da7..0c453741b 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -109,6 +109,8 @@ if (MSVC) else () set(CMAKE_GENERATOR_PLATFORM_LWR "") endif () +ggml_get_system_arch() +message(STATUS "GGML_SYSTEM_ARCH: ${GGML_SYSTEM_ARCH}") if (NOT MSVC) if (GGML_STATIC) @@ -123,7 +125,6 @@ if (NOT MSVC) endif() if (MINGW) - # Target Windows 8 for PrefetchVirtualMemory add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER}) endif() @@ -194,6 +195,7 @@ add_library(ggml-base ../include/ggml-opt.h ../include/gguf.h ggml.c + ggml.cpp ggml-alloc.c ggml-backend.cpp ggml-opt.cpp @@ -210,11 +212,12 @@ endif() add_library(ggml ggml-backend-reg.cpp) +add_library(ggml::ggml ALIAS ggml) target_link_libraries(ggml PUBLIC ggml-base) if (CMAKE_SYSTEM_NAME MATCHES "Linux") - target_link_libraries(ggml PRIVATE dl stdc++fs) + target_link_libraries(ggml PRIVATE dl) endif() function(ggml_add_backend_library backend) @@ -224,6 +227,7 @@ function(ggml_add_backend_library backend) set_target_properties(${backend} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) target_compile_definitions(${backend} PRIVATE GGML_BACKEND_DL) add_dependencies(ggml ${backend}) + install(TARGETS ${backend} LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}) else() add_library(${backend} ${ARGN}) target_link_libraries(ggml PUBLIC ${backend}) @@ -266,16 +270,23 @@ endfunction() function(ggml_add_cpu_backend_variant tag_name) set(GGML_CPU_TAG_NAME ${tag_name}) # other: OPENMP LLAMAFILE CPU_HBM - foreach (feat NATIVE - AVX AVX2 BMI2 AVX_VNNI FMA F16C - AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 - AMX_TILE AMX_INT8 AMX_BF16) - set(GGML_${feat} OFF) - endforeach() + if (GGML_SYSTEM_ARCH STREQUAL "x86") + foreach (feat NATIVE + SSE42 + AVX AVX2 BMI2 AVX_VNNI FMA F16C + AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 + AMX_TILE AMX_INT8 AMX_BF16) + set(GGML_${feat} OFF) + endforeach() - foreach (feat ${ARGN}) - set(GGML_${feat} ON) - endforeach() + foreach (feat ${ARGN}) + set(GGML_${feat} ON) + endforeach() + elseif (GGML_SYSTEM_ARCH STREQUAL "ARM") + foreach (feat ${ARGN}) + set(GGML_INTERNAL_${feat} ON) + endforeach() + endif() ggml_add_cpu_backend_variant_impl(${tag_name}) endfunction() @@ -285,15 +296,49 @@ ggml_add_backend(CPU) if (GGML_CPU_ALL_VARIANTS) if (NOT GGML_BACKEND_DL) message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL") + elseif (GGML_CPU_ARM_ARCH) + message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS") endif() - ggml_add_cpu_backend_variant(sandybridge AVX) - ggml_add_cpu_backend_variant(haswell AVX F16C AVX2 BMI2 FMA) - ggml_add_cpu_backend_variant(skylakex AVX F16C AVX2 BMI2 FMA AVX512) - ggml_add_cpu_backend_variant(icelake AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) - ggml_add_cpu_backend_variant(alderlake AVX F16C AVX2 BMI2 FMA AVX_VNNI) - if (NOT MSVC) - # MSVC doesn't support AMX - ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) + if (GGML_SYSTEM_ARCH STREQUAL "x86") + ggml_add_cpu_backend_variant(x64) + ggml_add_cpu_backend_variant(sse42 SSE42) + ggml_add_cpu_backend_variant(sandybridge SSE42 AVX) + ggml_add_cpu_backend_variant(haswell SSE42 AVX F16C AVX2 BMI2 FMA) + ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) + ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) + ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) + if (NOT MSVC) + # MSVC doesn't support AMX + ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8) + endif() + elseif(GGML_SYSTEM_ARCH STREQUAL "ARM") + if (CMAKE_SYSTEM_NAME MATCHES "Linux") + # Many of these features are optional so we build versions with popular + # combinations and name the backends based on the version they were + # first released with + ggml_add_cpu_backend_variant(armv8.0_1) + ggml_add_cpu_backend_variant(armv8.2_1 DOTPROD) + ggml_add_cpu_backend_variant(armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC) + ggml_add_cpu_backend_variant(armv8.2_3 DOTPROD FP16_VECTOR_ARITHMETIC SVE) + ggml_add_cpu_backend_variant(armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8) + ggml_add_cpu_backend_variant(armv8.6_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2) + ggml_add_cpu_backend_variant(armv9.2_1 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SME) + ggml_add_cpu_backend_variant(armv9.2_2 DOTPROD FP16_VECTOR_ARITHMETIC SVE MATMUL_INT8 SVE2 SME) + elseif (CMAKE_SYSTEM_NAME MATCHES "Android") + # Android-specific backends with SoC-compatible feature sets + ggml_add_cpu_backend_variant(android_armv8.0_1) + ggml_add_cpu_backend_variant(android_armv8.2_1 DOTPROD) + ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC) + ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8) + elseif (APPLE) + ggml_add_cpu_backend_variant(apple_m1 DOTPROD) + ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8) + ggml_add_cpu_backend_variant(apple_m4 DOTPROD MATMUL_INT8 NOSVE SME) + else() + message(FATAL_ERROR "Unsupported ARM target OS: ${CMAKE_SYSTEM_NAME}") + endif() + else() + message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS not yet supported with ${GGML_SYSTEM_ARCH} on ${CMAKE_SYSTEM_NAME}") endif() elseif (GGML_CPU) ggml_add_cpu_backend_variant_impl("") diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index a3d3f6901..5fd379f6a 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -816,7 +816,10 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) { size_t node_size = 0; if (!node->data && !node->view_src) { - GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API + // If we previously had data but don't now then reallocate + if (talloc->buffer_id < 0) { + return false; + } node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); } return talloc->size_max >= node_size; diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 273075f4e..b1050ad59 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -56,7 +56,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { return SIZE_MAX; } -size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor) { +size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { // get_alloc_size is optional, defaults to ggml_nbytes if (buft->iface.get_alloc_size) { size_t size = buft->iface.get_alloc_size(buft, tensor); @@ -152,7 +152,7 @@ size_t ggml_backend_buffer_get_max_size(ggml_backend_buffer_t buffer) { return ggml_backend_buft_get_max_size(ggml_backend_buffer_get_type(buffer)); } -size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor) { return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor); } @@ -674,6 +674,8 @@ struct ggml_backend_sched { char * context_buffer; size_t context_buffer_size; + bool op_offload; + int debug; }; @@ -766,7 +768,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor); // check if a backend with higher prio wants to offload the op - if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { + if (sched->op_offload && src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) { for (int b = 0; b < src_backend_id; b++) { if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) { SET_CAUSE(tensor, "1.off"); @@ -1109,7 +1111,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg const int node_backend_id = tensor_backend_id(node); - assert(node_backend_id != -1); // all nodes should be assigned by now + assert(node_backend_id != -1); // all nodes should be assigned by now, this can happen if there is no CPU fallback // check if we should start a new split based on the sources of the current node bool need_new_split = false; @@ -1338,7 +1340,10 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { // the re-allocation may cause the split inputs to be moved to a different address - ggml_backend_sched_synchronize(sched); + // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy + for (int i = 0; i < sched->n_backends; i++) { + ggml_backend_synchronize(sched->backends[i]); + } #ifndef NDEBUG GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); #endif @@ -1452,7 +1457,8 @@ ggml_backend_sched_t ggml_backend_sched_new( ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, - bool parallel) { + bool parallel, + bool op_offload) { GGML_ASSERT(n_backends > 0); GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_dev_type(ggml_backend_get_device(backends[n_backends - 1])) == GGML_BACKEND_DEVICE_TYPE_CPU); @@ -1497,6 +1503,7 @@ ggml_backend_sched_t ggml_backend_sched_new( } sched->galloc = ggml_gallocr_new_n(sched->bufts, n_backends); + sched->op_offload = op_offload; ggml_backend_sched_reset(sched); @@ -1560,7 +1567,6 @@ bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgra ggml_backend_sched_split_graph(sched, graph); - if (!ggml_backend_sched_alloc_splits(sched)) { return false; } @@ -1594,6 +1600,12 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } + if (!sched->is_alloc) { + // if the graph is not already allocated, always use copy 0 after a synchronization + // this ensures that during generation the same copy is used every time, + // which avoids changes in the graph that could cause CUDA or other graphs to be disabled + sched->cur_copy = 0; + } } void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt index 0bf3c05d9..76064c3fd 100644 --- a/ggml/src/ggml-blas/CMakeLists.txt +++ b/ggml/src/ggml-blas/CMakeLists.txt @@ -81,7 +81,7 @@ if (BLAS_FOUND) target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES}) target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS}) else() - message(ERROR "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct GGML_BLAS_VENDOR") + message(FATAL_ERROR "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct GGML_BLAS_VENDOR") endif() diff --git a/ggml/src/ggml-cann/CMakeLists.txt b/ggml/src/ggml-cann/CMakeLists.txt old mode 100644 new mode 100755 index 0d8e483b2..7742b3915 --- a/ggml/src/ggml-cann/CMakeLists.txt +++ b/ggml/src/ggml-cann/CMakeLists.txt @@ -30,6 +30,7 @@ string(TOLOWER ${SOC_TYPE} SOC_VERSION) # SOC_VERSION need lower string(REGEX MATCH "[0-9]+[a-zA-Z]" SOC_TYPE_MAJOR_SN "${SOC_VERSION}") set(SOC_TYPE_COMPILE_OPTION "ASCEND_${SOC_TYPE_MAJOR_SN}") string(TOUPPER ${SOC_TYPE_COMPILE_OPTION} SOC_TYPE_COMPILE_OPTION) +message(STATUS "CANN: SOC_VERSION = ${SOC_VERSION}") if (CANN_INSTALL_DIR) # Only Support Linux. diff --git a/ggml/src/ggml-cann/Doxyfile b/ggml/src/ggml-cann/Doxyfile old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp old mode 100644 new mode 100755 index f5462c5a1..f311864d4 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) { return ACL_FLOAT; case GGML_TYPE_F16: return ACL_FLOAT16; + case GGML_TYPE_BF16: + return ACL_BF16; case GGML_TYPE_I8: return ACL_INT8; case GGML_TYPE_I16: diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h old mode 100644 new mode 100755 diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp old mode 100644 new mode 100755 index 37d411797..437ece2d4 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -64,6 +64,9 @@ #include #include #include +#include +#include +#include #include #include @@ -72,11 +75,13 @@ #include #include "ggml-impl.h" +#include "ggml.h" #define GGML_COMMON_DECL_C #include "../ggml-common.h" + void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclTensor ** acl_src0, aclTensor ** acl_src1, aclTensor ** acl_dst) { GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_can_repeat(src1, src0)); @@ -102,9 +107,7 @@ void ggml_cann_unary_op( aclTensor* acl_dst = ggml_cann_create_tensor(dst); unary_op(ctx, acl_src, acl_dst); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } /** @@ -122,8 +125,8 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src, // repeat tensor along each dim with repeat_array aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS); - GGML_CANN_CALL_ACLNN_OP(Repeat, acl_src, repeats, acl_dst); - ACL_CHECK(aclDestroyIntArray(repeats)); + GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_src, repeats, acl_dst); + ggml_cann_release_resources(ctx, repeats); } /** @@ -141,24 +144,7 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src, */ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, aclDataType cast_data_type) { - GGML_CANN_CALL_ACLNN_OP(Cast, acl_src, cast_data_type, acl_dst); -} - -/** - * @brief Casts the elements of a tensor to a specified data type using the CANN backend. - * - * @details This function performs a type conversion on the elements of the input tensor `acl_src` - * and stores the results in the destination tensor `acl_dst`. The conversion type is - * determined based on the `dst` tensor's data type. - * - * @param ctx The context for the CANN backend operations. - * @param acl_src The source tensor whose elements will be cast. - * @param acl_dst The destination tensor that will store the casted elements. - * @param dst The ggml tensor specifying the target data type. - */ -static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src, - aclTensor* acl_dst, ggml_tensor* dst) { - aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); + GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src, cast_data_type, acl_dst); } void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -172,8 +158,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]}; aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, @@ -181,10 +166,10 @@ void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0, float alphaValue = 1.0f; aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); if (acl_dst != nullptr) - GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst); else - GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_src0, acl_src1, alpha); - ACL_CHECK(aclDestroyScalar(alpha)); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_src0, acl_src1, alpha); + ggml_cann_release_resources(ctx, alpha); } void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, @@ -192,26 +177,26 @@ void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0, float alphaValue = 1.0f; aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); if (acl_dst != nullptr) - GGML_CANN_CALL_ACLNN_OP(Sub, acl_src0, acl_src1, alpha, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Sub, acl_src0, acl_src1, alpha, acl_dst); else - GGML_CANN_CALL_ACLNN_OP(InplaceSub, acl_src0, acl_src1, alpha); - ACL_CHECK(aclDestroyScalar(alpha)); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSub, acl_src0, acl_src1, alpha); + ggml_cann_release_resources(ctx, alpha); } void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_other, aclTensor* acl_dst) { if (acl_dst != nullptr) - GGML_CANN_CALL_ACLNN_OP(Mul, acl_src, acl_other, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_src, acl_other, acl_dst); else - GGML_CANN_CALL_ACLNN_OP(InplaceMul, acl_src, acl_other); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_src, acl_other); } void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_other, aclTensor* acl_dst) { if (acl_dst != nullptr) - GGML_CANN_CALL_ACLNN_OP(Div, acl_src, acl_other, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_other, acl_dst); else - GGML_CANN_CALL_ACLNN_OP(InplaceDiv, acl_src, acl_other); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDiv, acl_src, acl_other); } /** @@ -240,11 +225,11 @@ static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src, float scale, aclTensor* acl_dst, bool inplace) { aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT); if (inplace) { - GGML_CANN_CALL_ACLNN_OP(InplaceMuls, acl_src, acl_scale); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_src, acl_scale); } else { - GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, acl_scale, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, acl_scale, acl_dst); } - ACL_CHECK(aclDestroyScalar(acl_scale)); + ggml_cann_release_resources(ctx, acl_scale); } void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -261,11 +246,8 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclScalar* acl_negative_slope = aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(LeakyRelu, acl_src, acl_negative_slope, acl_dst); - - ACL_CHECK(aclDestroyScalar(acl_negative_slope)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, LeakyRelu, acl_src, acl_negative_slope, acl_dst); + ggml_cann_release_resources(ctx, acl_negative_slope, acl_src, acl_dst); } /** @@ -281,7 +263,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) { static void aclnn_concat(ggml_backend_cann_context& ctx, aclTensorList* tensorList, aclTensor* acl_dst, int64_t concat_dim) { - GGML_CANN_CALL_ACLNN_OP(Cat, tensorList, concat_dim, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Cat, tensorList, concat_dim, acl_dst); } void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -297,11 +279,10 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int32_t acl_dim = 3 - dim; aclTensor* tensors[] = {acl_src0, acl_src1}; - aclTensorList* tensorList = aclCreateTensorList(tensors, 2); - aclnn_concat(ctx, tensorList, acl_dst, acl_dim); + aclTensorList* tensor_list = aclCreateTensorList(tensors, 2); + aclnn_concat(ctx, tensor_list, acl_dst, acl_dim); - ACL_CHECK(aclDestroyTensorList(tensorList)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, tensor_list, acl_dst); } /** @@ -331,10 +312,8 @@ static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst, aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT); aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(Arange, acl_start, acl_end, acl_step, acl_dst); - ACL_CHECK(aclDestroyScalar(acl_start)); - ACL_CHECK(aclDestroyScalar(acl_end)); - ACL_CHECK(aclDestroyScalar(acl_step)); + GGML_CANN_CALL_ACLNN_OP(ctx, Arange, acl_start, acl_end, acl_step, acl_dst); + ggml_cann_release_resources(ctx, acl_start, acl_end, acl_step); } void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -351,7 +330,7 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) { memcpy(&step, (float*)dst->op_params + 2, sizeof(float)); aclnn_arange(ctx, acl_dst, start, stop, step, n_elements); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_dst); } void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -368,11 +347,8 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT); aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(Clamp, acl_src, acl_min, acl_max, acl_dst); - ACL_CHECK(aclDestroyScalar(acl_min)); - ACL_CHECK(aclDestroyScalar(acl_max)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_src, acl_min, acl_max, acl_dst); + ggml_cann_release_resources(ctx, acl_min, acl_max, acl_src, acl_dst); } void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -386,10 +362,8 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_src = ggml_cann_create_tensor(src); aclTensor* acl_dst = ggml_cann_create_tensor(dst); - GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, scale, acl_dst); - ACL_CHECK(aclDestroyScalar(scale)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, scale, acl_dst); + ggml_cann_release_resources(ctx, scale, acl_src, acl_dst); } void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -404,12 +378,10 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* tmp_tensor = ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type), dst->ne, dst->nb, GGML_MAX_DIMS); - GGML_CANN_CALL_ACLNN_OP(Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), + GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor); - GGML_CANN_CALL_ACLNN_OP(Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(tmp_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst); + ggml_cann_release_resources(ctx, acl_src, tmp_tensor, acl_dst); } void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -423,11 +395,9 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { std::vector normData = {dst->ne[0]}; aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size()); - GGML_CANN_CALL_ACLNN_OP(LayerNorm, acl_src, norm, nullptr, nullptr, + GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src, norm, nullptr, nullptr, eps, acl_dst, nullptr, nullptr); - ACL_CHECK(aclDestroyIntArray(norm)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, norm, acl_src, acl_dst); } void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -457,12 +427,9 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_rstd_out = ggml_cann_create_tensor( (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); - GGML_CANN_CALL_ACLNN_OP(GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, + GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst, acl_mean_out, acl_rstd_out); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(acl_mean_out)); - ACL_CHECK(aclDestroyTensor(acl_rstd_out)); + ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_mean_out, acl_rstd_out); } void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -487,19 +454,17 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { if (!inplace) { size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); aclTensor* acl_src0 = ggml_cann_create_tensor( src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); - GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src0)); + GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst); + ggml_cann_release_resources(ctx, acl_src0); } else { - GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, acl_src1, alpha); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, acl_src1, alpha); } - - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src1, acl_dst); } /** @@ -512,7 +477,6 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param dim An array of dimension indices. * @param dim_size The number of dimensions. */ - static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst, int64_t* dim, size_t dim_size) { GGML_ASSERT(dst->ne[0] == 1); @@ -521,11 +485,9 @@ static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclTensor* acl_dst = ggml_cann_create_tensor(dst); aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size); - GGML_CANN_CALL_ACLNN_OP(ReduceSum, acl_src, reduce_dims, true, + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true, ggml_cann_type_mapping(dst->type), acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyIntArray(reduce_dims)); + ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims); } void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -549,10 +511,8 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx, std::vector output_size{dst->ne[1], dst->ne[0]}; auto output_size_array = aclCreateIntArray(output_size.data(), 2); - GGML_CANN_CALL_ACLNN_OP(UpsampleNearest2d, acl_src, output_size_array, acl_dst); - ACL_CHECK(aclDestroyIntArray(output_size_array)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, UpsampleNearest2d, acl_src, output_size_array, acl_dst); + ggml_cann_release_resources(ctx, acl_src, acl_dst, output_size_array); } /** @@ -575,9 +535,8 @@ static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2); aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst); - ACL_CHECK(aclDestroyIntArray(acl_pad)); - ACL_CHECK(aclDestroyScalar(acl_value)); + GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst); + ggml_cann_release_resources(ctx, acl_pad, acl_value); } void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -593,9 +552,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) { 0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1], 0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]}; aclnn_pad(ctx, acl_src, acl_dst, paddings); - - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(acl_src)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } /** @@ -641,14 +598,15 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx, bool count_include_pad = true; int64_t divisor_override = 0; int8_t cube_math_type = 0; - GGML_CANN_CALL_ACLNN_OP(AvgPool2d, acl_src, kernel_size, strides, paddings_avg, +#ifdef ASCEND_310P + cube_math_type = 1; +#endif + + GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src, kernel_size, strides, paddings_avg, ceil_mode, count_include_pad, divisor_override, cube_math_type, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyIntArray(kernel_size)); - ACL_CHECK(aclDestroyIntArray(strides)); - ACL_CHECK(aclDestroyIntArray(paddings_avg)); + ggml_cann_release_resources(ctx, acl_src, acl_dst, kernel_size, strides, + paddings_avg); } /** @@ -716,15 +674,10 @@ static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx, bool ceil_mode = false; int64_t auto_pads = 0; - GGML_CANN_CALL_ACLNN_OP(MaxPool, tmp_tensor, kernel_size, strides, auto_pads, + GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations, ceil_mode, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(tmp_tensor)); - ACL_CHECK(aclDestroyIntArray(kernel_size)); - ACL_CHECK(aclDestroyIntArray(strides)); - ACL_CHECK(aclDestroyIntArray(paddings_max)); - ACL_CHECK(aclDestroyIntArray(dilations)); + ggml_cann_release_resources(ctx, acl_src, acl_dst, tmp_tensor, kernel_size, + strides, paddings_max, dilations); } void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -755,7 +708,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) { */ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(InplaceCopy, acl_dst, acl_src); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst, acl_src); } void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -767,15 +720,14 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { if (dst->type == src0->type) { cann_copy(ctx, acl_src, acl_dst); } else { - aclnn_cast(ctx, acl_src, acl_dst, dst); + aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type)); } } else { if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) { if (dst->type == src0->type) { size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync( - dst->data, cpy_size, src0->data, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); return; } else { ggml_cann_pool_alloc src_buffer_allocator( @@ -792,12 +744,11 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync( - dst->data, cpy_size, src_trans_buffer, cpy_size, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - ACL_CHECK(aclDestroyTensor(src_trans_tensor)); + ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); + ggml_cann_release_resources(ctx, src_trans_tensor); return; } } else if (ggml_is_contiguous(dst)) { @@ -814,21 +765,18 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type)); size_t cpy_size = ggml_nbytes(dst); - ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer, - cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, - ctx.stream())); - ACL_CHECK(aclDestroyTensor(src_trans_tensor)); + ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size, + ACL_MEMCPY_DEVICE_TO_DEVICE); + ggml_cann_release_resources(ctx, src_trans_tensor); return; } else { GGML_ABORT("Unsupport dst is not tontiguous."); } } - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } /** @@ -856,7 +804,7 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer, nb[i] = nb[i - 1] * ne[i - 1]; } - ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream())); + ggml_cann_async_memset(ctx, buffer, n_bytes, 0); aclTensor* zero = ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims); return zero; @@ -889,7 +837,7 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer, float alpha_host = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT); aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(InplaceAdds, acl_tensor, other, alpha); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_tensor, other, alpha); return acl_tensor; } @@ -915,11 +863,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes, src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type), ggml_element_size(src)); - GGML_CANN_CALL_ACLNN_OP(RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyTensor(acl_gamma)); - ACL_CHECK(aclDestroyTensor(acl_rstd)); + GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd); + ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd); } // TODO: performace is low. @@ -945,13 +890,10 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float alphaValue = 1.0f; alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(InplaceTriu, mask_tensor, n_past + 1); - GGML_CANN_CALL_ACLNN_OP(Tril, acl_src, n_past + 1, acl_dst); - GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, mask_tensor, alpha); - ACL_CHECK(aclDestroyScalar(alpha)); - ACL_CHECK(aclDestroyTensor(mask_tensor)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceTriu, mask_tensor, n_past + 1); + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src, n_past + 1, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, mask_tensor, alpha); + ggml_cann_release_resources(ctx, alpha, acl_src, acl_dst, mask_tensor); } /** @@ -972,7 +914,8 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) { aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims); - GGML_CANN_CALL_ACLNN_OP(Permute, acl_src, acl_dims, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Permute, acl_src, acl_dims, acl_dst); + ggml_cann_release_resources(ctx, acl_dims); } static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, @@ -993,8 +936,7 @@ static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx, aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3); } - // release - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_dst); } static void ggml_cann_im2col_1d_post_process( @@ -1016,7 +958,6 @@ static void ggml_cann_im2col_1d_post_process( // Permute: [N, IC * KH * KW, OW * OH] -> // [N, OW * OH * n_bytes_factor, IC * KH * KW] - aclTensor* tmp_permute_tensor = nullptr; ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool()); tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor); void* tmp_permute_buffer = tmp_permute_allocator.get(); @@ -1028,7 +969,7 @@ static void ggml_cann_im2col_1d_post_process( tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1]; } - tmp_permute_tensor = ggml_cann_create_tensor( + aclTensor* tmp_permute_tensor = ggml_cann_create_tensor( tmp_permute_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); @@ -1058,9 +999,8 @@ static void ggml_cann_im2col_1d_post_process( c * KH * KW * n_step_w * ggml_type_size(dst->type); for (int i = 0; i < n_step_w; i++) { - ACL_CHECK(aclrtMemcpyAsync( - cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, cur_dst_buffer, cur_permute_buffer, size_cpy, + ACL_MEMCPY_DEVICE_TO_DEVICE); cur_dst_buffer = (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type); cur_permute_buffer = (char*)cur_permute_buffer + @@ -1070,13 +1010,11 @@ static void ggml_cann_im2col_1d_post_process( } else { offset = KH * KW * n_step_w * ggml_type_size(dst->type); // equal to ggml_nbytes(dst) - ACL_CHECK(aclrtMemcpyAsync(dst->data, offset, - (char*)tmp_permute_buffer + offset, offset, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_async_memcpy(ctx, dst->data, (char*)tmp_permute_buffer + offset, offset, + ACL_MEMCPY_DEVICE_TO_DEVICE); } - // release - ACL_CHECK(aclDestroyTensor(tmp_permute_tensor)); + ggml_cann_release_resources(ctx, tmp_permute_tensor); } void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -1138,7 +1076,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { auto* dilations = aclCreateIntArray(dilation_size.data(), 2); auto* paddings = aclCreateIntArray(padding_dims.data(), 2); auto* strides = aclCreateIntArray(stride_dims.data(), 2); - GGML_CANN_CALL_ACLNN_OP(Im2col, acl_src1, kernel_size, dilations, + GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1, kernel_size, dilations, paddings, strides, tmp_im2col_tensor); // Cast if dst is f16. @@ -1158,7 +1096,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { tmp_cast_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb, GGML_MAX_DIMS - 1, ACL_FORMAT_ND); - aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, dst); + aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor, ggml_cann_type_mapping(dst->type)); } // post-processing @@ -1172,14 +1110,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { tmp_im2col_tensor, im2col_op_params); } - // release - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_cast_tensor)); - ACL_CHECK(aclDestroyIntArray(kernel_size)); - ACL_CHECK(aclDestroyIntArray(dilations)); - ACL_CHECK(aclDestroyIntArray(paddings)); - ACL_CHECK(aclDestroyIntArray(strides)); + ggml_cann_release_resources(ctx, acl_src1, tmp_im2col_tensor, tmp_cast_tensor, + kernel_size, dilations, paddings, strides); } /** @@ -1196,17 +1128,17 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param acl_src The tensor on which the exponential function will be applied. */ static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) { - GGML_CANN_CALL_ACLNN_OP(InplaceExp, acl_src); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceExp, acl_src); } void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(Cos, acl_src, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst); } void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(Sin, acl_src, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst); } void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, @@ -1255,13 +1187,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src)); void* tmp_permute_buffer = permute_allocator.get(); - aclTensor* tmp_permute_tenosr = ggml_cann_create_tensor( + aclTensor* tmp_permute_tensor = ggml_cann_create_tensor( tmp_permute_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); int64_t permute_dim[] = {0, 1, 3, 2}; int64_t num_dims = 4; - aclnn_permute(ctx, acl_src, tmp_permute_tenosr, permute_dim, num_dims); + aclnn_permute(ctx, acl_src, tmp_permute_tensor, permute_dim, num_dims); // timestep * freq int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2], @@ -1282,7 +1214,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, tmp_mul_buffer, ggml_cann_type_mapping(src->type), ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_mul(ctx, tmp_permute_tenosr, tmp_arange_tensor, tmp_mul_tensor); + aclnn_mul(ctx, tmp_permute_tensor, tmp_arange_tensor, tmp_mul_tensor); // cos ggml_cann_pool_alloc cos_allocator( @@ -1310,17 +1242,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, int64_t concat_dim = 3; aclTensor* acl_dst = ggml_cann_create_tensor(dst); aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor}; - aclTensorList* tensorList = aclCreateTensorList(tensors, 2); - aclnn_concat(ctx, tensorList, acl_dst, concat_dim); + aclTensorList* tensor_list = aclCreateTensorList(tensors, 2); + aclnn_concat(ctx, tensor_list, acl_dst, concat_dim); // release // segmentation fault when delete both tensorList and his elements. - ACL_CHECK(aclDestroyTensorList(tensorList)); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(tmp_arange_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr)); - ACL_CHECK(aclDestroyTensor(tmp_mul_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, tensor_list, acl_src, tmp_arange_tensor, + tmp_permute_tensor, tmp_mul_tensor, acl_dst); } /** @@ -1336,8 +1264,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, aclTensor* acl_dst) { auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(InplaceFillScalar, acl_dst, acl_scalar); - ACL_CHECK(aclDestroyScalar(acl_scalar)); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar); + ggml_cann_release_resources(ctx, acl_scalar); } /** @@ -1358,7 +1286,7 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar, */ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx, aclTensor* acl_dst, aclTensor* acl_exp) { - GGML_CANN_CALL_ACLNN_OP(InplacePowTensorTensor, acl_dst, acl_exp); + GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp); } /** @@ -1510,15 +1438,9 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src, // add aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst); - - ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_arange_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mk_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_output_tensor)); + ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, + tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, + tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor); } void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -1541,7 +1463,7 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) { */ static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src, int64_t dim, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(Softmax, acl_src, dim, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst); } void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { @@ -1591,8 +1513,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src1_fp32_nb, GGML_MAX_DIMS); aclTensor* acl_src1 = ggml_cann_create_tensor(src1); aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT); - - ACL_CHECK(aclDestroyTensor(acl_src1)); + ggml_cann_release_resources(ctx, acl_src1); } else { acl_src1_fp32_tensor = ggml_cann_create_tensor(src1); } @@ -1645,17 +1566,13 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // softmax aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst); - ACL_CHECK(aclDestroyTensor(alibi_output_tensor)); + ggml_cann_release_resources(ctx, alibi_output_tensor); } else { aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst); } - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyScalar(acl_scale)); - ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor)); - ACL_CHECK(aclDestroyTensor(tmp_mask_tensor)); + ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst, + acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor); } /** @@ -1702,10 +1619,8 @@ static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer, (char*)dst->data + i * dst->nb[3] + j * dst->nb[2], ggml_cann_type_mapping(dst->type), ggml_element_size(dst), acl_out_ne, acl_out_nb, 2); - GGML_CANN_CALL_ACLNN_OP(Embedding, acl_src_tensor, acl_index, acl_out); - ACL_CHECK(aclDestroyTensor(acl_src_tensor)); - ACL_CHECK(aclDestroyTensor(acl_index)); - ACL_CHECK(aclDestroyTensor(acl_out)); + GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out); + ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out); } } } @@ -1733,11 +1648,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* src_trans_tensor = ggml_cann_create_tensor( src_trans_buffer, ACL_FLOAT, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0, src_trans_tensor, dst); + aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type)); aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, src1, dst); - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(src_trans_tensor)); + ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor); break; } case GGML_TYPE_Q8_0: { @@ -1783,7 +1697,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, GGML_MAX_DIMS + 1); aclTensor* acl_scale_tensor = ggml_cann_create_tensor( - src0->data, ACL_FLOAT16, sizeof(float16_t), scale_ne, scale_nb, + src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); aclTensor* dequant_tensor = ggml_cann_create_tensor( dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t), @@ -1799,7 +1713,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, src1, dst); - ACL_CHECK(aclDestroyTensor(dequant_tensor)); + ggml_cann_release_resources(ctx, dequant_tensor); break; } default: @@ -1827,7 +1741,7 @@ static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t dim, int64_t repeats, int64_t output_size) { - GGML_CANN_CALL_ACLNN_OP(RepeatInterleaveIntWithDim, acl_src, repeats, dim, + GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim, output_size, acl_dst); } @@ -1876,21 +1790,19 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, switch (n_dims) { case 2: - GGML_CANN_CALL_ACLNN_OP(Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2); break; case 3: - GGML_CANN_CALL_ACLNN_OP(BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2); break; default: // ALLOW_FP32_DOWN_PRECISION, when input is // fp32, atlas a2 will transpose it to HFLOAT32. - GGML_CANN_CALL_ACLNN_OP(Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1); + GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1); break; } - ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_input_tensor, acl_dst); } /** @@ -1960,9 +1872,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne, input_cast_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16); - - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); - ACL_CHECK(aclDestroyTensor(acl_src1_tensor)); + ggml_cann_release_resources(ctx, acl_input_tensor, acl_src1_tensor); } // output @@ -2015,13 +1925,11 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, if (src0->ne[0] > QK8_0) { antiquantGroupSize = QK8_0; } - GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor, + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr, nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor); - ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); - ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_output_tensor)); + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor); // other splits for (int64_t split = 1; split < split_size; split++) { @@ -2048,16 +1956,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16, output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND, output_ne_offset); - GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor, + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr, nullptr, nullptr, nullptr, antiquantGroupSize, acl_output_tensor); - ACL_CHECK(aclDestroyTensor(acl_weight_tensor)); - ACL_CHECK(aclDestroyTensor(acl_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_output_tensor)); + ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor); } - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); + ggml_cann_release_resources(ctx, acl_input_tensor); } } @@ -2074,10 +1980,9 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx, output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne, output_cast_nb, GGML_MAX_DIMS); aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); - aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, dst); + aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); - ACL_CHECK(aclDestroyTensor(acl_output_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst_tensor)); + ggml_cann_release_resources(ctx, acl_output_tensor, acl_dst_tensor); } } @@ -2118,9 +2023,8 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst, int64_t* shifts, int64_t* dims) { aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1); aclIntArray* acl_dims = aclCreateIntArray(dims, 1); - GGML_CANN_CALL_ACLNN_OP(Roll, acl_src, acl_shifts, acl_dims, acl_dst); - ACL_CHECK(aclDestroyIntArray(acl_shifts)); - ACL_CHECK(aclDestroyIntArray(acl_dims)); + GGML_CANN_CALL_ACLNN_OP(ctx, Roll, acl_src, acl_shifts, acl_dims, acl_dst); + ggml_cann_release_resources(ctx, acl_shifts, acl_dims); } /** @@ -2142,9 +2046,8 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx, float value) { aclIntArray* acl_index = aclCreateIntArray(index, index_num); aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value); - ACL_CHECK(aclDestroyIntArray(acl_index)); - ACL_CHECK(aclDestroyScalar(acl_value)); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value); + ggml_cann_release_resources(ctx, acl_index, acl_value); } static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, @@ -2159,37 +2062,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src2 = dst->src[2]; // freq_factors - // arange, [0,1,...,ne0/2] - int64_t arange_length = src0->ne[0] / 2; - ggml_cann_pool_alloc arange_allocator(ctx.pool(), - arange_length * sizeof(float_t)); - void* arange_buffer = arange_allocator.get(); - int64_t arange_ne[] = {arange_length, 1, 1, 1}; - size_t arange_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), - arange_length * sizeof(float_t)}; + GGML_TENSOR_BINARY_OP_LOCALS - aclTensor* acl_arange_tensor = - ggml_cann_create_tensor(arange_buffer, ACL_FLOAT, sizeof(float_t), - arange_ne, arange_nb, GGML_MAX_DIMS); + // theta_scale arange, [0,1,...,ne00/2 - 1] + int64_t theta_scale_length = ne00 / 2; + ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), + theta_scale_length * sizeof(float_t)); + void* theta_scale_buffer = theta_scale_allocator.get(); + int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; + size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), + theta_scale_length * sizeof(float_t)}; + + aclTensor* acl_theta_scale_tensor = + ggml_cann_create_tensor(theta_scale_buffer, ACL_FLOAT, sizeof(float_t), + theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); float start = 0; float step = 1; - float stop = src0->ne[0] / 2; - float n_elements = src0->ne[0] / 2; - aclnn_arange(ctx, acl_arange_tensor, start, stop, step, n_elements); + float stop = ne00 / 2; + float n_elements = ne00 / 2; + aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); // power - // aclnnPowScalarTensor(): @param self is tensor which should be scalar, so - // use aclnn_pow_tensor_tensor() until fixed. aclScalar* acl_theta_scale = - // aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); - // aclnn_power_scalar_tensor(ctx, acl_theta_scale, acl_arange_tensor, - // acl_power_tensor); - ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(), - arange_length * sizeof(float_t)); - void* theta_scale_buffer = theta_scale_allocator.get(); - aclTensor* acl_theta_scale_tensor = aclnn_values( - ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne, - GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale); - aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor); + aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, + acl_theta_scale_tensor); // freq_scale if (freq_scale != 1) { @@ -2200,28 +2096,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, if (src2) { aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( src2->data, ggml_cann_type_mapping(src2->type), - ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS); + ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor); - ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor)); + ggml_cann_release_resources(ctx, acl_freq_factors_tensor); } // position GGML_ASSERT(src1->type == GGML_TYPE_I32); int64_t position_length = src1->ne[0]; - int64_t position_ne[] = {1, position_length, 1, 1}; - size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), - sizeof(int32_t) * position_length, + int64_t position_ne[] = {1, 1, position_length, 1}; + size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length}; aclTensor* acl_position_tensor = ggml_cann_create_tensor( src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS); // power * position - int64_t theta_length = arange_length * position_length; + int64_t theta_length = theta_scale_length * position_length; ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* theta_buffer = theta_allocator.get(); - int64_t theta_ne[] = {arange_length, position_length, 1, 1}; + int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1}; size_t theta_nb[GGML_MAX_DIMS]; theta_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2233,40 +2128,22 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, acl_theta_tensor); - // permute: [0,1,2,3]->[0,2,1,3] - int64_t permute_ne[] = {arange_length, 1, position_length, 1}; - size_t permute_nb[GGML_MAX_DIMS]; - permute_nb[0] = sizeof(float_t); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - permute_nb[i] = permute_nb[i - 1] * permute_ne[i - 1]; - } - ggml_cann_pool_alloc permute_allocator(ctx.pool(), - theta_length * sizeof(float_t)); - void* permute_buffer = permute_allocator.get(); - aclTensor* acl_permute_tensor = ggml_cann_create_tensor( - permute_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, - GGML_MAX_DIMS, ACL_FORMAT_ND); - int64_t permute_dim[] = {0, 2, 1, 3}; - int64_t num_dims = 4; - aclnn_permute(ctx, acl_theta_tensor, acl_permute_tensor, permute_dim, - num_dims); - // sin/cos ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* sin_buffer = sin_allocator.get(); aclTensor* acl_sin_tensor = ggml_cann_create_tensor( - sin_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, + sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_sin(ctx, acl_permute_tensor, acl_sin_tensor); + aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float_t)); void* cos_buffer = cos_allocator.get(); aclTensor* acl_cos_tensor = ggml_cann_create_tensor( - cos_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb, + cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND); - aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor); + aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); // attn_factor if (attn_factor != 1) { @@ -2282,7 +2159,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } else { int64_t num_repeats = 2; int64_t dim = 3; - int64_t output_size = arange_length * num_repeats; + int64_t output_size = theta_scale_length * num_repeats; aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim, num_repeats, output_size); aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim, @@ -2290,13 +2167,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, } // release - ACL_CHECK(aclDestroyTensor(acl_arange_tensor)); - ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_position_tensor)); - ACL_CHECK(aclDestroyTensor(acl_theta_tensor)); - ACL_CHECK(aclDestroyTensor(acl_permute_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_tensor)); - ACL_CHECK(aclDestroyTensor(acl_cos_tensor)); + ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, + acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale); } #ifdef __cplusplus @@ -2318,7 +2190,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // TODO: use ascendc // Only test with LLAMA model. ggml_tensor* src0 = dst->src[0]; // input - // ggml_tensor* src2 = dst->src[2]; // freq_factors, not used now. // param float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; @@ -2353,13 +2224,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { // init cos/sin cache ggml_cann_pool_alloc sin_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t)); + ctx.pool(), ne00 * ne02 * sizeof(float_t)); ggml_cann_pool_alloc cos_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t)); + ctx.pool(), ne00 * ne02 * sizeof(float_t)); void* sin_buffer = sin_allocator.get(); void* cos_buffer = cos_allocator.get(); - int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; + int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; size_t sin_reshape_nb[GGML_MAX_DIMS]; sin_reshape_nb[0] = sizeof(float_t); for (int i = 1; i < GGML_MAX_DIMS; i++) { @@ -2372,7 +2243,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t), sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor, - theta_scale, freq_scale, attn_factor, is_neox); + theta_scale, freq_scale, attn_factor, is_neox); aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_dst = ggml_cann_create_tensor(dst); @@ -2409,8 +2280,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t shifts[] = {1}; int64_t dims[] = {3}; aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); - ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); + ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); // init [-1, 1, -1, 1, ...] minus_one_scale_buffer = minus_one_scale_allocator.get(); @@ -2446,8 +2316,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { int64_t dims[] = {3}; aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims); - ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_tensor)); + ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor); // init [-1, -1, -1, 1, 1,1,...] minus_one_scale_buffer = minus_one_scale_allocator.get(); int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; @@ -2472,7 +2341,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { bool inplace = true; float scale = -1; aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace); - ACL_CHECK(aclDestroyTensor(acl_first_half_tensor)); + ggml_cann_release_resources(ctx, acl_first_half_tensor); } // TODO: n_dims < ne0 @@ -2537,62 +2406,62 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { output_fp32_tensor); aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16); - ACL_CHECK(aclDestroyTensor(input_fp32_tensor1)); - ACL_CHECK(aclDestroyTensor(input_fp32_tensor2)); - ACL_CHECK(aclDestroyTensor(output_fp32_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor)); - ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_src)); + ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2, + output_fp32_tensor, acl_sin_reshape_tensor, + acl_minus_one_tensor, acl_input_roll_mul_scale_tensor, + acl_input_roll_reshape_tensor, acl_src); } return; #endif - // src0 == GGML_TYPE_F16 - // TODO: optimization this `if` code - if (src0->type == GGML_TYPE_F16) { - ggml_cann_pool_alloc sin_final_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); - ggml_cann_pool_alloc cos_final_allocator( - ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type)); - void* sin_final_buffer = sin_final_allocator.get(); - void* cos_final_buffer = cos_final_allocator.get(); + // ggml_mode = 0 --> aclnn_model = 1 + int64_t acl_mode = mode == 0 ? 1 : mode; - int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1}; - size_t sin_final_nb[GGML_MAX_DIMS]; - sin_final_nb[0] = ggml_type_size(src0->type); - for (int i = 1; i < GGML_MAX_DIMS; i++) { - sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1]; + switch (src0->type) { + case GGML_TYPE_F32: { + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src, + acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst); + break; } - aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor( - sin_final_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), sin_final_ne, sin_final_nb, - GGML_MAX_DIMS); - aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor( - cos_final_buffer, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), sin_final_ne, sin_final_nb, - GGML_MAX_DIMS); + case GGML_TYPE_F16: { + ggml_cann_pool_alloc src_trans_allocator( + ctx.pool(), ggml_nelements(src0) * sizeof(float)); + void* src_trans_buffer = src_trans_allocator.get(); + ggml_cann_pool_alloc dst_trans_allocator( + ctx.pool(), ggml_nelements(dst) * sizeof(float)); + void* dst_trans_buffer = dst_trans_allocator.get(); - aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, dst); - aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, dst); - ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); - acl_sin_reshape_tensor = acl_sin_final_tensor; - acl_cos_reshape_tensor = acl_cos_final_tensor; + size_t src_trans_nb[GGML_MAX_DIMS]; + src_trans_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + } + + aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor( + src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, + GGML_MAX_DIMS); + aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor( + dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, + GGML_MAX_DIMS); + + aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor, + acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, + acl_dst_trans_tensor); + + aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16); + + ggml_cann_release_resources(ctx, acl_src_trans_tensor, + acl_dst_trans_tensor); + break; + } + default: + GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE"); + break; } - - int acl_mode = mode; - if (mode == 0) { - acl_mode = 1; - } - - GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor, - acl_sin_reshape_tensor, acl_mode, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_cos_reshape_tensor, + acl_sin_reshape_tensor, acl_src, acl_dst); } @@ -2602,10 +2471,9 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3); - GGML_CANN_CALL_ACLNN_OP(ArgMax, acl_src, 3, false, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src, 3, false, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ @@ -2630,14 +2498,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds int64_t groups = 1; int8_t cubeMathType = 0; - GGML_CANN_CALL_ACLNN_OP(Convolution, acl_input, acl_weight, nullptr, stride, +#ifdef ASCEND_310P + cubeMathType = 1; +#endif + + GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride, padding, dilation, transposed, padding, groups, acl_dst, cubeMathType); - ACL_CHECK(aclDestroyTensor(acl_weight)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyIntArray(stride)); - ACL_CHECK(aclDestroyIntArray(padding)); - ACL_CHECK(aclDestroyIntArray(dilation)); + ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation); } void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){ @@ -2650,12 +2518,10 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclScalar* alpha = nullptr; alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(Elu, acl_input, alpha, alpha, alpha, + GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input, alpha, alpha, alpha, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_input)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyScalar(alpha)); + ggml_cann_release_resources(ctx, acl_input, acl_dst, alpha); } void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){ @@ -2668,11 +2534,9 @@ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1); bool keepDim = true; - GGML_CANN_CALL_ACLNN_OP(Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyIntArray(reduceDim)); + ggml_cann_release_resources(ctx, acl_src, acl_dst, reduceDim); } void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ @@ -2692,12 +2556,11 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ ggml_cann_type_mapping(dst->type), ggml_element_size(dst), dst->ne, dst->nb, 3); - GGML_CANN_CALL_ACLNN_OP(ReflectionPad1d, acl_src, paddings, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src, paddings, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } - ACL_CHECK(aclDestroyIntArray(paddings)); + ggml_cann_release_resources(ctx, paddings); } void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){ @@ -2707,12 +2570,11 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclTensor* acl_self = ggml_cann_create_tensor(src0); aclTensor* acl_other = ggml_cann_create_tensor(src1); - GGML_CANN_CALL_ACLNN_OP(InplaceEqTensor, acl_self, acl_other); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self, acl_other); ggml_cann_sum(ctx, dst); - ACL_CHECK(aclDestroyTensor(acl_self)); - ACL_CHECK(aclDestroyTensor(acl_other)); + ggml_cann_release_resources(ctx, acl_self, acl_other); } void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ @@ -2725,9 +2587,607 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){ aclScalar* alpha = nullptr; alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(GtScalar, acl_src, alpha, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src, alpha, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); - ACL_CHECK(aclDestroyScalar(alpha)); + ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha); +} + +/** + * @brief Performs expert-specific matrix multiplication (MoE) with + * floating-point precision using the CANN backend. + * + * This function executes a matrix multiplication operation tailored for + * Mixture of Experts (MoE) models, where the input tensor is multiplied + * with expert-specific weight matrices. It uses the CANN backend for + * efficient computation and stores the result in the destination tensor `dst`. + * The operation may leverage identity-based optimizations or routing masks + * as part of sparse expert selection. + * + * @param ctx The context for executing CANN backend operations. + * @param dst The destination tensor where the MoE multiplication result + * will be stored. + * + * @note This function assumes floating-point data types and is designed for + * MoE architectures, possibly involving sparse expert routing. + */ +static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + //dst [M, K, N, 1] + ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] + ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 + ggml_tensor * ids = dst->src[2]; //ids [K, N] + + GGML_TENSOR_BINARY_OP_LOCALS + + // copy index from npu to cpu + int64_t n_as = ne02; // A + int64_t n_ids = ids->ne[0]; // K + + std::vector ids_host(ggml_nbytes(ids)); + ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), + ACL_MEMCPY_DEVICE_TO_HOST); + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + + char * src0_original = (char *) src0->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03}; + + // src0 is F16, src1 is F32, dst is F32 + ggml_cann_pool_alloc src0_cast_allocator; + if (src0->type == GGML_TYPE_F16) { + src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0)); + void* src0_cast_buf = src0_cast_allocator.get(); + + size_t cast_nb[GGML_MAX_DIMS]; + cast_nb[0] = sizeof(float_t); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1]; + } + + aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0); + aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf, + ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4); + GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast); + ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16); + + src0_original = (char *) src0_cast_buf; + memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb)); + } + + std::vector src0_tensor_vec; + std::vector src1_tensor_vec; + std::vector dst_tensor_vec; + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + // src0_row [M, D] -> weight && permute + int64_t src0_ne[2] = {ne01, ne00}; + size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]}; + // src1_row [D, 1] -> input + int64_t src1_ne[2] = {ne10, 1}; + size_t src1_nb[2] = {nb10, nb11}; + // dst_row [M, 1] -> out + int64_t dst_ne[2] = {ne0, 1}; + size_t dst_nb[2] = {nb0, nb1}; + + // expert index + int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + // If B = 1 (broadcast), always use 0; otherwise, use id. + int64_t i11 = (ne11 == 1 ? 0 : id); + int64_t i12 = iid1; + + int64_t i1 = id; + int64_t i2 = i12; + + void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2]; + void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; + void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + + aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr, + ACL_FLOAT, sizeof(float), + src0_ne, src0_nb, 2); + aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr, + ACL_FLOAT, sizeof(float), + src1_ne, src1_nb, 2); + aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr, + ACL_FLOAT, sizeof(float), + dst_ne, dst_nb, 2); + + src0_tensor_vec.push_back(acl_src0); + src1_tensor_vec.push_back(acl_src1); + dst_tensor_vec.push_back(acl_dst); + } + } + + size_t GROUP_SIZE = 128; + // GroupedMatmulV2 required tensor_list.size < 128 + for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) { + // split and call GroupedMatmulV2 + size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size()); + std::vector src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end); + std::vector src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end); + std::vector dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end); + + aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size()); + aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size()); + aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size()); + + GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV2, src1_tensor_list, src0_tensor_list, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list); + + ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list); + } + return; +} + +/** + * @brief Performs expert-specific matrix multiplication (MoE) with + * quantized precision using the CANN backend. + * + * This function executes a matrix multiplication operation tailored for + * Mixture of Experts (MoE) models, where the input tensor is multiplied + * with expert-specific quantized weight matrices. It leverages the CANN + * backend to perform efficient low-precision computations and stores the + * quantized result in the destination tensor `dst`. + * + * Quantization techniques reduce memory footprint and improve performance + * by using lower-bit representations (e.g., int8) instead of floating-point. + * This function is designed to work with such formats and may incorporate + * optimizations like identity-based fast paths or routing masks for sparse + * expert selection. + * + * @param ctx The context for executing CANN backend operations. + * @param dst The destination tensor where the quantized MoE multiplication result + * will be stored. + * + * @note This function assumes quantized data types and is designed for + * MoE architectures with potential sparse expert routing. + */ +static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + // TODO: Use aclnnGroupedMatMul + //dst [M, K, N, 1] + ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] + ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 + ggml_tensor * ids = dst->src[2]; //ids [K, N] + + GGML_TENSOR_BINARY_OP_LOCALS + + // copy index from npu to cpu + int64_t n_as = ne02; // A + int64_t n_ids = ids->ne[0]; // K + + std::vector ids_host(ggml_nbytes(ids)); + ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), + ACL_MEMCPY_DEVICE_TO_HOST); + ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + + char * src0_original = (char *) src0->data; + char * src1_original = (char *) src1->data; + char * dst_original = (char *) dst->data; + + ggml_tensor src0_row = *src0; + ggml_tensor src1_row = *src1; + ggml_tensor dst_row = *dst; + + const enum ggml_type type = dst->src[0]->type; + float weight_elem_size; + if (type == GGML_TYPE_Q4_0) { + weight_elem_size = float(sizeof(uint8_t)) / 2; + } else if (type == GGML_TYPE_Q8_0) { + weight_elem_size = float(sizeof(uint8_t)); + } else { + GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); + } + + // src0_row [D, M, 1, 1] weight without permute + src0_row.ne[2] = 1; + src0_row.ne[3] = 1; + src0_row.nb[0] = weight_elem_size; + src0_row.nb[1] = weight_elem_size * ne00; + src0_row.nb[2] = weight_elem_size * ne00; + src0_row.nb[3] = weight_elem_size * ne00; + size_t weight_stride = ne00 * ne01 * weight_elem_size; + size_t weight_size = weight_stride * ne02 * ne03; + + // scale [D, M, 1, 1] -> scale && permute + size_t scale_elem_size = sizeof(uint16_t); + size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + + // src1_row [D, 1, 1, 1] -> input + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; + + // dst_row [M, 1, 1, 1] -> out + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; + + //create weight for one row + ggml_cann_pool_alloc weight_allocator(ctx.pool()); + void* weight_buffer = weight_allocator.alloc(nb02); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + // expert index + int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + // If B = 1 (broadcast), always use 0; otherwise, use id. + int64_t i11 = (ne11 == 1 ? 0 : id); + int64_t i12 = iid1; + + int64_t i1 = id; + int64_t i2 = i12; + + void* src0_tmp_ptr = src0_original + i02*weight_stride; + void* scale_tmp_ptr = src0_original + weight_size + i02*scale_stride; + void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12; + void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2; + + // mem cpy + ggml_cann_async_memcpy(ctx, weight_buffer, src0_tmp_ptr, weight_stride, + ACL_MEMCPY_DEVICE_TO_DEVICE); + void* scale_buffer = (char*)weight_buffer + weight_stride; + ggml_cann_async_memcpy(ctx, scale_buffer, scale_tmp_ptr, scale_stride, + ACL_MEMCPY_DEVICE_TO_DEVICE); + + src0_row.data = weight_buffer; + src1_row.data = src1_tmp_ptr; + dst_row.data = dst_tmp_ptr; + dst_row.src[0] = &src0_row; + dst_row.src[1] = &src1_row; + + ggml_cann_mul_mat(ctx, &dst_row); + } + } + return; +} + +void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + const enum ggml_type type = dst->src[0]->type; + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + ggml_cann_mul_mat_id_fp(ctx, dst); + break; + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + ggml_cann_mul_mat_id_quant(ctx, dst); + break; + default: + GGML_ABORT("Unsupported type for mul_mat_id"); + break; + } +} + +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){ + + ggml_tensor* src0 = dst->src[0]; // q, fp32 + ggml_tensor* src1 = dst->src[1]; // k, fp16 + ggml_tensor* src2 = dst->src[2]; // v, fp16 + ggml_tensor* src3 = dst->src[3]; // mask, fp16 + + float maxBias = 0.0f; + float scaleValue = 1.0f; + float logitSoftcap = 0.0f; + memcpy(&scaleValue, (float*)dst->op_params + 0, sizeof(float)); + memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float)); + memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float)); + + if(logitSoftcap == 0.0f){ + size_t faElemSize = sizeof(uint16_t); + auto faDataType = ACL_FLOAT16; //ACL_BF16; + + aclTensor* acl_src0_f16_tensor = nullptr; + aclTensor* acl_src1_f16_tensor = nullptr; + aclTensor* acl_src2_f16_tensor = nullptr; + aclTensor* acl_dst_f16_tensor = nullptr; + + // Step 1: cast the src0 (Query) to fp16 if needed + ggml_cann_pool_alloc src0_f16_allocator(ctx.pool()); + void* src0_f16_buffer = nullptr; + + if(ggml_cann_type_mapping(src0->type) != faDataType){ + aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0); + src0_f16_buffer = src0_f16_allocator.alloc( + ggml_nelements(src0) * faElemSize); + + int64_t* src0_f16_ne = src0->ne; + size_t src0_f16_nb[GGML_MAX_DIMS]; + src0_f16_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1]; + } + + acl_src0_f16_tensor = ggml_cann_create_tensor( + src0_f16_buffer, faDataType, faElemSize, + src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS + ); + aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType); + ggml_cann_release_resources(ctx, acl_src0_f32_tensor); + }else{ + acl_src0_f16_tensor = ggml_cann_create_tensor(src0); + } + + // Step 2: create the acl tensors for src1 (Key), src2 (Value), + // and the direct output from FusedInferAttention + + acl_src1_f16_tensor = ggml_cann_create_tensor(src1); + acl_src2_f16_tensor = ggml_cann_create_tensor(src2); + + ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); + void* out_f16_buffer = out_f16_allocator.alloc( + ggml_nelements(dst) * faElemSize); + + int64_t* out_f16_ne = src0->ne; + size_t out_f16_nb[GGML_MAX_DIMS]; + out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; + } + + acl_dst_f16_tensor = ggml_cann_create_tensor( + out_f16_buffer, faDataType, faElemSize, + out_f16_ne, out_f16_nb, GGML_MAX_DIMS + ); + + // Step 3: create the PSEShift tensor if needed + // this tensor is considered as mask (f16) in the llama.cpp + + aclTensor* bcast_pse_tensor = nullptr; + int64_t bcast_pse_ne[GGML_MAX_DIMS]; + size_t bcast_pse_nb[GGML_MAX_DIMS]; + ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool()); + void* bcast_pse_buffer = nullptr; + + if(src3 != nullptr){ + bcast_pse_buffer = bcast_pse_allocator.alloc( + ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)); + + if(src0->ne[1] > 1){ + // Case 1: broadcast pse for prefill stage with multiple head + aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3); + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src3->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats); + + ggml_cann_release_resources(ctx, acl_mask_f16_tensor); + }else{ + // Case 2: trunc the first row and broadcast pse for decode stage with multiple head + int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]}; + size_t* trunc_pse_nb = src3->nb; + + aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor( + src3->data, ACL_FLOAT16, sizeof(uint16_t), + trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS); + + bcast_pse_ne[0] = src3->ne[0]; + bcast_pse_ne[1] = src0->ne[1]; + bcast_pse_ne[2] = src0->ne[2]; + bcast_pse_ne[3] = src3->ne[3]; + + bcast_pse_nb[0] = sizeof(uint16_t); + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1]; + } + + bcast_pse_tensor = ggml_cann_create_tensor( + bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t), + bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS); + + int64_t repeats[] = {1, src0->ne[2], 1, 1}; + aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats); + + ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor); + } + + // Compute the slope if needed. Derived from ggml_cann_softmax(). + if(maxBias != 0.0f){ + // alibi + const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3]; + const int64_t n_head = src0->ne[2]; + const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head)); + float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor); + float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor); + // init arange + ggml_cann_pool_alloc arange_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_arange_buffer = arange_allocator.get(); + + // arange1: [1, ..., n_heads_log2_floor+1) + float start = 1; + float stop = n_heads_log2_floor + 1; + float step = 1; + int64_t n_elements_arange = n_heads_log2_floor; + + int64_t tmp_arange1_ne[] = {n_heads_log2_floor}; + size_t tmp_arange1_nb[] = {faElemSize}; + aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_arange1_ne, tmp_arange1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange); + + aclTensor* tmp_arange2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1) + start = 1; + stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1; + step = 2; + n_elements_arange = ne2_ne3 - n_heads_log2_floor; + int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_arange2_nb[] = {faElemSize}; + + aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor( + (char*)tmp_arange_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step, + n_elements_arange); + } + + // init mk_base + ggml_cann_pool_alloc mk_base_allocator(ctx.pool(), + ne2_ne3 * faElemSize); + void* tmp_mk_base_buffer = mk_base_allocator.get(); + int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor}; + size_t tmp_mk_base1_nb[] = {faElemSize}; + aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base1_ne, tmp_mk_base1_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + + aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor); + + aclTensor* tmp_mk_base2_tensor = nullptr; + if (n_heads_log2_floor < ne2_ne3) { + int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor}; + size_t tmp_mk_base2_nb[] = {faElemSize}; + aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor( + (char*)tmp_mk_base_buffer + + n_heads_log2_floor * faElemSize, + faDataType, faElemSize, + tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor); + } + + // init mk + int64_t tmp_mk_base_ne[] = {ne2_ne3}; + size_t tmp_mk_base_nb[] = {faElemSize}; + aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclTensor* tmp_arange_tensor = ggml_cann_create_tensor( + tmp_arange_buffer, faDataType, faElemSize, + tmp_mk_base_ne, tmp_mk_base_nb, + GGML_MAX_DIMS - 3, ACL_FORMAT_ND); + aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor); + + // reshape mk + int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]}; + size_t tmp_mk_nb[GGML_MAX_DIMS]; + tmp_mk_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1]; + } + aclTensor* tmp_mk_tensor = ggml_cann_create_tensor( + tmp_mk_base_buffer, faDataType, faElemSize, + tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS, + ACL_FORMAT_ND); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor); + + ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor, + tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor, + tmp_arange_tensor, tmp_mk_tensor); + } + } + + // Step 4: set the inputs for FusedInferAttention. + int kvTensorNum = 1; + aclTensor* acl_q_tensor = acl_src0_f16_tensor; + aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor}; + aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor}; + auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum); + auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum); + + int64_t numHeads = src0->ne[2]; // N + int64_t numKeyValueHeads = src1->ne[2]; + // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d) + int64_t preTokens = 65535; + int64_t nextTokens = 65535; + char layout[5] = {'B', 'N', 'S', 'D', 0}; + int64_t sparseMode = 0; + int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2; + int64_t blockSize = 0; + int64_t antiquantMode = 0; + bool softmaxLseFlag = false; + int64_t keyAntiquantMode = 0; + int64_t valueAntiquantMode = 0; + + // Step 5: launch the FusedInferAttentionScoreV2 kernel. + // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md + + GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2, + acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v + bcast_pse_tensor, nullptr, // pse, mask + nullptr, nullptr, // actSeqLen, actSeqLenkv + nullptr, nullptr, // deqScale1, quantScale1 + nullptr, nullptr, nullptr, // deqScale2, quantScale2, quantOffset2 + nullptr, nullptr, // antiquantScale, antiquantOffset + nullptr, // blockTable + nullptr, nullptr, // qPadSize, kvPadSize + nullptr, nullptr, // kAntiquantScale, kAntiQuantOffset + nullptr, nullptr, // vAntiquantScale, vAntiQuantOffset + nullptr, nullptr, nullptr, // kSharedPrefix, vSharedPrefix, actSharedLen + numHeads, scaleValue, // heads, scaleValue + preTokens, nextTokens, // preTokens, nextTokens + layout, // inputLayout + numKeyValueHeads, // numKVHeads + sparseMode, innerPrecise, // sparseMode, innerPrecise + blockSize, antiquantMode, // blockSize, antiquantMode + softmaxLseFlag, // softmaxLseFlag + keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode + acl_dst_f16_tensor, // attentionOut + nullptr // softmaxLse + ); + + // Step 6: post-processing, permute and cast to f32 + + int64_t new_dim[] = {0, 2, 1, 3}; + aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst); + + if(ggml_cann_type_mapping(dst->type) != faDataType){ + ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool()); + perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); + void* perm_out_f16_buffer = perm_out_f16_allocator.get(); + + int64_t* perm_out_f16_ne = dst->ne; + size_t perm_out_f16_nb[GGML_MAX_DIMS]; + perm_out_f16_nb[0] = faElemSize; + for(int i = 1; i < GGML_MAX_DIMS; ++i){ + perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1]; + } + aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor( + perm_out_f16_buffer, faDataType, faElemSize, + perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS); + aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS); + aclnn_cast(ctx, + acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type)); + ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor); + }else{ + // only need to permute + aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS); + } + ggml_cann_release_resources(ctx, acl_src0_f16_tensor, + acl_src1_f16_tensor, + acl_src2_f16_tensor, + acl_dst_f16_tensor, + acl_dst_tensor); + if(src3 != nullptr){ + ggml_cann_release_resources(ctx, bcast_pse_tensor); + } + }else{ + GGML_ABORT("Function is not implemented."); + } } diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h old mode 100644 new mode 100755 index b2d1b3c36..80ce80bae --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -23,6 +23,7 @@ #ifndef CANN_ACLNN_OPS #define CANN_ACLNN_OPS +#include #include #include #include @@ -713,6 +714,312 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst); */ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Performs the Flash Attention extended operator using the CANN backend. + * + * @details This function implements the memory-efficient Flash Attention algorithm + * for computing scaled dot-product attention with hardware acceleration. + * The result is stored in the destination tensor `dst`. + * + * This operation is accelerated using the CANN backend to improve runtime performance. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the result will be stored. + * dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`. + */ +void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/* + * @brief A generic wrapper for ACL resources with custom deleter support. + */ +using any_acl_resource = std::unique_ptr>; + +/** + * @brief Trait structure used to define how to destroy a given ACL resource type. + * + * @tparam T ACL resource type. + */ +template +struct acl_resource_traits; + +/** + * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource. + */ +template<> +struct acl_resource_traits { + static void destroy(void* p) { + ACL_CHECK(aclDestroyTensor(static_cast(p))); + } +}; + +/** + * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource. + */ +template<> +struct acl_resource_traits { + static void destroy(void* p) { + ACL_CHECK(aclDestroyIntArray(static_cast(p))); + } +}; + +/** + * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource. + */ +template<> +struct acl_resource_traits { + static void destroy(void* p) { + ACL_CHECK(aclDestroyScalar(static_cast(p))); + } +}; + +/** + * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource. + */ +template<> +struct acl_resource_traits { + static void destroy(void* p) { + ACL_CHECK(aclDestroyTensorList(static_cast(p))); + } +}; + +/** + * @brief Creates a generic ACL resource wrapper with proper destruction logic. + * + * @tparam T ACL resource type. + * @param ptr Raw pointer to ACL resource. + * @return any_acl_resource Smart pointer that handles destruction. + */ +template +any_acl_resource make_acl_resource(T* ptr) { + return any_acl_resource( + static_cast(ptr), + [](void* p) { + acl_resource_traits::destroy(p); + } + ); +} + +/** + * @brief Registers multiple ACL resources into a vector for lifetime management. + * + * @tparam Args Variadic list of ACL resource types. + * @param vec Target vector to hold ACL resources. + * @param args Raw pointers to ACL resources. + */ +template +void register_acl_resources(std::vector& vec, Args*... args) { + (vec.emplace_back(make_acl_resource(args)), ...); +} + +/** + * @brief Task class that wraps the execution of an aclnn function call. + */ +class aclnn_task : public cann_task { + public: + aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr, + uint64_t workspace_size, aclOpExecutor * executor, + aclrtStream stream) : + aclnn_func_(aclnn_func), + workspace_addr_(workspace_addr), + workspace_size_(workspace_size), + executor_(executor), + stream_(stream) {} + virtual void run_task() override { + ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_)); + } + private: + aclnn_func_t aclnn_func_; + void * workspace_addr_; + uint64_t workspace_size_; + aclOpExecutor * executor_; + aclrtStream stream_; +}; + +/** + * @brief Task class that releases ACL resources after usage. + */ +class release_resource_task : public cann_task { +public: + release_resource_task(std::vector&& resources){ + resource_ = std::move(resources); + } + + virtual void run_task() override { + resource_.clear(); + } +private: + std::vector resource_; +}; + +/** + * @brief Task class for performing asynchronous memory copy operations. + */ +class async_memcpy_task : public cann_task { +public: + async_memcpy_task(void* dst, const void* src, size_t size, + aclrtMemcpyKind kind, aclrtStream stream) + : dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {} + + virtual void run_task() override { + ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_)); + } +private: + void* dst_; + const void* src_; + size_t size_; + aclrtMemcpyKind kind_; + aclrtStream stream_; +}; + +/** + * @brief Task class for performing asynchronous memory set operations. + */ +class async_memset_task : public cann_task { + public: + async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream) + : buffer_(buffer), size_(size), value_(value), stream_(stream) {} + + virtual void run_task() override { + ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_)); + } + private: + void* buffer_; + size_t size_; + int32_t value_; + aclrtStream stream_; +}; + +/** + * @brief Launches an asynchronous task using the memory allocator. + * + * This macro submit an asynchronous task on the specified stream. + * The task uses memory allocated by the allocator. It is guaranteed + * that the memory will not be accessed by other tasks until this task + * completes, due to the sequential execution order within the same stream. + * + * @param OP_NAME aclnn operator name. + * @param args Additional arguments required by the task. + * + * @note + * Memory from the allocator will be "freed" immediately and can be + * reallocated to other pointers. However, it won't be accessed by any + * other task before this asynchronous task ends, because all tasks in the + * same stream are executed in queue order. + */ + +#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ + do { \ + uint64_t workspaceSize = 0; \ + aclOpExecutor * executor; \ + void * workspaceAddr = nullptr; \ + ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\ + /* workspace should alloced in main thread to keep malloc order when using vmm. */ \ + if (workspaceSize > 0) { \ + ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ + workspaceAddr = workspace_allocator.get(); \ + } \ + if (CTX.async_mode) { \ + auto task = \ + std::make_unique(aclnn##OP_NAME, workspaceAddr, workspaceSize, \ + executor, CTX.stream()); \ + CTX.task_queue.submit_task(std::move(task)); \ + } else { \ + ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\ + } \ + } while (0) + +/** + * @brief Registers and releases multiple ACL resources, optionally deferring the release + * using a task. + * + * @tparam Args Types of the ACL resources. + * @param ctx Backend context which manages task submission and async mode. + * @param args Pointers to ACL resources to be released. + */ +template +void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) { + std::vector resources; + register_acl_resources(resources, std::forward(args)...); + if(ctx.async_mode) { + auto task = std::make_unique(std::move(resources)); + ctx.task_queue.submit_task(std::move(task)); + } +} + +/** + * @brief Performs an asynchronous memory copy operation, optionally deferred via task submission. + * + * @param ctx Backend context containing stream and async configuration. + * @param dst Destination memory address. + * @param src Source memory address. + * @param len Size of memory to copy (in bytes). + * @param kind Type of memory copy (host-to-device, device-to-host, etc). + */ +inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst, + const void * src, size_t len, aclrtMemcpyKind kind) { + if (ctx.async_mode) { + auto task = std::make_unique(dst, const_cast(src), len, kind, ctx.stream()); + ctx.task_queue.submit_task(std::move(task)); + } else { + ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream())); + } +} + +inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst, + const void * src, size_t len, aclrtMemcpyKind kind) { + if (ctx->async_mode) { + auto task = std::make_unique(dst, const_cast(src), len, kind, ctx->stream()); + ctx->task_queue.submit_task(std::move(task)); + } else { + ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream())); + } +} + +/** + * @brief Performs an asynchronous memory set operation, optionally deferred via task submission. + * + * @param ctx Backend context containing stream and async configuration. + * @param buffer Memory buffer to be set. + * @param size Size of the memory buffer (in bytes). + * @param value Value to set in the buffer. + */ +inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer, + size_t size, int value) { + if (ctx.async_mode) { + auto task = std::make_unique(buffer, size, value, ctx.stream()); + ctx.task_queue.submit_task(std::move(task)); + } else { + ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream())); + } +} + +/** + * @brief Performs sparse expert-based matrix multiplication using the CANN backend. + * + * @details This function implements a MoE-style batched matrix multiplication, where each input token + * is routed to one or more experts, and each expert corresponds to a specific [D, M] weight matrix + * in the source tensor `src0`. The routing indices are provided via the `ids` tensor. + * + * For each token (from `src1`), the function selects the corresponding expert(s) as specified by `ids`, + * performs the matrix multiplication with the selected expert's weight submatrix (from `src0`), + * and stores the results in `dst`. This operation is optimized and executed on the CANN backend. + * + * Dimensions: + * - src0: [D, M, A, 1], where A is the number of experts + * - src1: [D, B, N, 1], where N is batch size and B is the slot count per sample + * - ids : [K, N], where K is the number of experts each token is routed to + * - dst : [M, K, N, 1], output tensor storing the result of expert × token multiplication + * + * The function handles two main modes: + * - If `ne12 == 1`, a simpler per-token loop is used. + * - TODO: If `ne12 > 1`, grouped multiplication and memory copying is used for efficiency. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor where the expert-weighted token outputs are stored. + * Expected to be of shape [M, K, N, 1]. + */ +void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst); + /** * @brief Applies a element-wise operation to two input tensors using the CANN * backend. @@ -742,42 +1049,9 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst); binary_op(ctx, acl_src0, acl_src1, acl_dst); - ACL_CHECK(aclDestroyTensor(acl_src0)); - ACL_CHECK(aclDestroyTensor(acl_src1)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst); } -/** - * @brief Launches an asynchronous task using the memory allocator. - * - * This macro submit an asynchronous task on the specified stream. - * The task uses memory allocated by the allocator. It is guaranteed - * that the memory will not be accessed by other tasks until this task - * completes, due to the sequential execution order within the same stream. - * - * @param OP_NAME aclnn operator name. - * @param args Additional arguments required by the task. - * - * @note - * Memory from the allocator will be "freed" immediately and can be - * reallocated to other pointers. However, it won't be accessed by any - * other task before this asynchronous task ends, because all tasks in the - * same stream are executed in queue order. - */ -#define GGML_CANN_CALL_ACLNN_OP(OP_NAME, ...) \ - do { \ - uint64_t workspaceSize = 0; \ - aclOpExecutor * executor; \ - void * workspaceAddr = nullptr; \ - \ - ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \ - \ - if (workspaceSize > 0) { \ - ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize); \ - workspaceAddr = workspace_allocator.get(); \ - } \ - ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream())); \ - } while (0) /** * @brief Applies a unary operation to an input tensor using the CANN backend. @@ -799,9 +1073,7 @@ template aclTensor* acl_dst = ggml_cann_create_tensor(dst); unary_op(ctx, acl_src, acl_dst); - - ACL_CHECK(aclDestroyTensor(acl_src)); - ACL_CHECK(aclDestroyTensor(acl_dst)); + ggml_cann_release_resources(ctx, acl_src, acl_dst); } /** @@ -832,7 +1104,7 @@ void ggml_cann_unary_op( * * Internally, the lambda will call: * @code - * GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst); + * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); * @endcode * * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP. @@ -840,14 +1112,14 @@ void ggml_cann_unary_op( * @see ggml_cann_unary_op * @see GGML_CANN_CALL_ACLNN_OP */ -#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \ - do { \ - auto lambda = [](ggml_backend_cann_context& ctx, \ - aclTensor* acl_src, \ - aclTensor* acl_dst) { \ - GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst); \ - }; \ - ggml_cann_unary_op(lambda, ctx, dst); \ - } \ +#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \ + do { \ + auto lambda = [](ggml_backend_cann_context& ctx, \ + aclTensor* acl_src, \ + aclTensor* acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_unary_op(lambda, ctx, dst); \ + } \ while (0) #endif // CANN_ACLNN_OPS diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h old mode 100644 new mode 100755 index 5164cb74e..ba2cef0c2 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -31,9 +31,17 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include #include "../include/ggml-cann.h" #include "../include/ggml.h" +#include "../ggml-impl.h" #define MATRIX_ROW_PADDING 512 #define GGML_CANN_MAX_STREAMS 8 @@ -96,6 +104,9 @@ const ggml_cann_device_info& ggml_cann_info(); void ggml_cann_set_device(int32_t device); int32_t ggml_cann_get_device(); +std::optional get_env(const std::string& name); +bool parse_bool(const std::string& value); + /** * @brief Abstract base class for memory pools used by CANN. */ @@ -205,6 +216,127 @@ struct ggml_cann_pool_alloc { ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete; }; +/** + * @brief Function pointer type for ACLNN operator calls. + */ +using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream); + +/** + * @brief Base class for all CANN tasks to be submitted to the task queue. + * + * Users should override the run_task() method with actual task logic. + */ +class cann_task { +public: + virtual void run_task() {} +}; + +/** + * @brief A lock-free ring-buffer based task queue for asynchronously executing cann_task instances. + */ +class cann_task_queue { +public: + /** + * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device. + * + * @param capacity Queue capacity. Must be a power of 2. + * @param device Target device ID (used for context setting). + */ + explicit cann_task_queue(size_t capacity, int32_t device) + : buffer_(capacity), capacity_(capacity), head_(0), tail_(0), + running_(false), device_(device) { + GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2"); + mask_ = capacity_ - 1; + } + + /** + * @brief Attempts to enqueue a task into the queue. + * + * @param item Unique pointer to the task. + * @return true if the task was successfully enqueued, false if the queue was full. + */ + bool enqueue(std::unique_ptr&& item) { + size_t next_tail = (tail_ + 1) & mask_; + + if (next_tail == head_) { + return false; + } + + buffer_[tail_] = std::move(item); + std::atomic_thread_fence(std::memory_order_release); + tail_ = next_tail; + + return true; + } + + /** + * @brief Submits a task to the queue, and starts the worker thread if not already running. + * + * @param task Task to be submitted. + */ + void submit_task(std::unique_ptr&& task) { + while(!enqueue(std::move(task))) { + std::this_thread::yield(); + continue; + } + + if (!running_) { + running_ = true; + thread_ = std::thread(&cann_task_queue::execute, this); + } + + } + + /** + * @brief Waits until the queue is completely empty and no tasks are being processed. + */ + void wait() { + while (running_ && head_ != tail_) { + std::this_thread::yield(); + continue; + } + } + + /** + * @brief Stops the task queue and joins the worker thread. + */ + void stop() { + running_ = false; + if (thread_.joinable()) { + thread_.join(); + } + } + +private: + /** + * @brief Worker thread function that continuously dequeues and executes tasks. + */ + void execute() { + ggml_cann_set_device(device_); + + while (running_) { + if(head_ == tail_) { + std::this_thread::yield(); + continue; + } + + std::atomic_thread_fence(std::memory_order_acquire); + buffer_[head_]->run_task(); + buffer_[head_].reset(); + head_ = (head_ + 1) & mask_; + } + } + + std::vector> buffer_; + const size_t capacity_; + size_t mask_; + size_t head_; + size_t tail_; + bool running_; + std::thread thread_; + int32_t device_; +}; + /** * @brief Context for managing CANN backend operations. */ @@ -213,6 +345,8 @@ struct ggml_backend_cann_context { std::string name; /**< Name of the device. */ std::string description; /**< Description of the device. */ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ + cann_task_queue task_queue; + bool async_mode; aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ @@ -221,9 +355,13 @@ struct ggml_backend_cann_context { * @param device Device ID. */ explicit ggml_backend_cann_context(int device) - : device(device), name("CANN" + std::to_string(device)) { + : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) { ggml_cann_set_device(device); description = aclrtGetSocName(); + + bool async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); + GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, + device, async_mode ? "ON" : "OFF"); } /** @@ -231,6 +369,7 @@ struct ggml_backend_cann_context { */ ~ggml_backend_cann_context() { ggml_cann_set_device(device); + task_queue.stop(); if (copy_event != nullptr) { ACL_CHECK(aclrtDestroyEvent(copy_event)); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp old mode 100644 new mode 100755 index cec36b36e..d1a0ad374 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -29,11 +29,16 @@ #include #include #include +#include +#include +#include +#include #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cann/aclnn_ops.h" #include "ggml-cann/common.h" +#include "ggml.h" #define GGML_COMMON_DECL_C @@ -90,6 +95,26 @@ int32_t ggml_cann_get_device() { return id; } +/** + * @brief Get the value of the specified environment variable (name). + * if not empty, return a std::string object + */ +std::optional get_env(const std::string& name) { + const char* val = std::getenv(name.c_str()); + if (!val) return std::nullopt; + std::string res = std::string(val); + std::transform(res.begin(), res.end(), res.begin(), ::tolower); + return res; +} + +/** + * @brief Verify whether the environment variable is a valid value. + */ +bool parse_bool(const std::string& value) { + std::unordered_set valid_values = {"on", "1", "yes", "y", "enable", "true"}; + return valid_values.find(value) != valid_values.end(); +} + /** * @brief Initialize the CANN device information. * @@ -119,9 +144,10 @@ static ggml_cann_device_info ggml_cann_init() { prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE; prop.location.id = id; prop.reserve = 0; - ACL_CHECK(aclrtMemGetAllocationGranularity( + err = aclrtMemGetAllocationGranularity( &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED, - &info.devices[id].vmm_granularity)); + &info.devices[id].vmm_granularity); + info.devices[id].vmm = err == ACL_SUCCESS; size_t free, total; ggml_backend_cann_get_device_memory(id, &free, &total); @@ -148,11 +174,223 @@ const ggml_cann_device_info& ggml_cann_info() { //#define DEBUG_CANN_MALLOC /** - * @brief A pool of CANN buffers(legacy). + * @brief A pool of CANN buffers(priority segment buffer). * * This class manages a pool of CANN buffers for a specific device. */ -struct ggml_cann_pool_leg : public ggml_cann_pool { +struct ggml_cann_pool_buf_prio : public ggml_cann_pool { + /** + * @brief The maximum reuse margin for a buffer. + */ + static const size_t max_reuse_margin = 1ull << 22; // 4MB + + /** + * @brief The minimum free margin for a buffer. + */ + static const size_t min_free_margin = 1ull << 20; // 1MB + + /** + * @brief The alignment for buffer allocation. + */ + static const size_t alignment = 128; + + /** + * @brief The device ID associated with this buffer pool. + */ + int device; + + /** + * @brief Whether to disable clean during buffer allocation. + */ + bool disable_clean = false; + + /** + * @brief Structure representing a CANN buffer. + */ + struct ggml_cann_buffer { + void* ptr = nullptr; ///< Pointer to the buffer. + size_t size = 0; ///< Size of the buffer. + std::chrono::steady_clock::time_point last_used; ///< Last used time. + + bool operator>(const ggml_cann_buffer& other) const { + return size > other.size; + } + }; + + /** + * @brief Array of CANN buffers in the pool. + */ + std::unordered_map buffer_pool; + std::priority_queue, + std::greater<>> free_buffers ; + + /** + * @brief Total size of all buffers in the pool. + */ + size_t pool_size = 0; + + /** + * @brief Constructor to initialize the buffer pool for a specific device. + * + * @param device The device ID to associate with this buffer pool. + */ + explicit ggml_cann_pool_buf_prio(int device) : device(device) { + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + } + + /** + * @brief Destructor to free all buffers in the pool. + */ + ~ggml_cann_pool_buf_prio() { + ggml_cann_set_device(device); + for (auto& [b_ptr, b_size] : buffer_pool) { + aclrtFree(b_ptr); + pool_size -= b_size; + } + buffer_pool.clear(); + GGML_ASSERT(pool_size == 0); + } + + /** + * @brief Allocate a buffer of the given size. + * + * @param size The size of the buffer to allocate. + * @param actual_size A pointer to a variable to receive the actual size of + * the allocated buffer. + * @return A pointer to the allocated buffer. + */ + void* alloc(size_t size, size_t* actual_size) override { + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } + + void* ptr = nullptr; + auto now = std::chrono::steady_clock::now(); + + std::vector free_buffers_rest; + free_buffers_rest.reserve(free_buffers.size()); + while (!free_buffers.empty()) { + auto b = free_buffers.top(); + free_buffers.pop(); + + if (b.size >= size) { + // reuse the buffer if the size is enough + const size_t margin = b.size - size; + if (margin <= max_reuse_margin) { + *actual_size = b.size; + ptr = b.ptr; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: reused %p, " + "pool_size = %5u MB, " + "size = %5u MB, " + "margin = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); +#endif + break; + } + } + + bool should_clean = !disable_clean && + b.size > min_free_margin && + std::chrono::duration_cast(now - b.last_used).count() > 100; + if (should_clean) { + // free the buffer if the size is needed to be freed + ACL_CHECK(aclrtFree(b.ptr)); + pool_size -= b.size; + buffer_pool.erase(b.ptr); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: clean %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); +#endif + continue; + } + free_buffers_rest.push_back(b); + } + for (ggml_cann_buffer &b : free_buffers_rest) { + free_buffers.push(std::move(b)); + } + +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + if (ptr != nullptr) { + return ptr; + } + + // allocate a new buffer if no buffer can be reused + ggml_cann_set_device(device); + ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + *actual_size = size; + pool_size += size; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: allocate %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, ptr, (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576)); +#endif + buffer_pool.emplace(ptr, size); + return ptr; + } + + /** + * @brief Free a buffer and return it to the pool. + * + * @param ptr Pointer to the buffer to free. + * @param size Size of the buffer to free. + */ + void free(void* ptr, size_t size) override { + GGML_UNUSED(size); + auto it = buffer_pool.find(ptr); + if (it == buffer_pool.end()) { + GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr); + } + + auto now = std::chrono::steady_clock::now(); + free_buffers.emplace(ggml_cann_buffer{ptr, it->second, now}); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: return %p, " + "pool_size = %5u MB\n", + device, ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + } +}; + +/** + * @brief A pool of CANN buffers(segment buffer). + * + * This class manages a pool of CANN buffers for a specific device. + */ +struct ggml_cann_pool_buf : public ggml_cann_pool { + /** + * @brief The maximum reuse margin for a buffer. + */ + static const size_t max_reuse_margin = 1ull << 22; // 4MB + + /** + * @brief The minimum free margin for a buffer. + */ + static const size_t min_free_margin = 1ull << 20; // 1MB + + /** + * @brief The alignment for buffer allocation. + */ + static const size_t alignment = 128; + /** * @brief The maximum number of buffers in the pool. */ @@ -163,12 +401,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { */ int device; + /** + * @brief Whether to disable clean during buffer allocation. + */ + bool disable_clean = false; + /** * @brief Structure representing a CANN buffer. */ struct ggml_cann_buffer { void* ptr = nullptr; ///< Pointer to the buffer memory. size_t size = 0; ///< Size of the buffer. + bool used = false; ///< Whether the buffer is currently in use. + std::chrono::steady_clock::time_point last_used; ///< Last used time. }; /** @@ -186,17 +431,19 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * * @param device The device ID to associate with this buffer pool. */ - explicit ggml_cann_pool_leg(int device) : device(device) {} + explicit ggml_cann_pool_buf(int device) : device(device) { + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + } /** * @brief Destructor to free all buffers in the pool. */ - ~ggml_cann_pool_leg() { + ~ggml_cann_pool_buf() { ggml_cann_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; if (b.ptr != nullptr) { - ACL_CHECK(aclrtFree(b.ptr)); + aclrtFree(b.ptr); pool_size -= b.size; } } @@ -212,63 +459,93 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * @return A pointer to the allocated buffer. */ void* alloc(size_t size, size_t* actual_size) override { - const size_t alignment = 128; size = GGML_PAD(size, alignment); if (size == 0) { size = alignment; } -#ifdef DEBUG_CANN_MALLOC - int nnz = 0; - size_t max_size = 0; -#endif - size_t best_diff = 1ull << 36; - int ibest = -1; - for (int i = 0; i < MAX_BUFFERS; ++i) { + + void* ptr = nullptr; + auto now = std::chrono::steady_clock::now(); + + int i = 0; + for (; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr != nullptr) { + if (b.ptr == nullptr) { + break; + } + if (b.used) { + continue; + } + if (b.size >= size) { + // reuse the buffer if the size is enough + const size_t margin = b.size - size; + if (margin <= max_reuse_margin) { + *actual_size = b.size; + b.used = true; + ptr = b.ptr; #ifdef DEBUG_CANN_MALLOC - ++nnz; - if (b.size > max_size) max_size = b.size; + GGML_LOG_INFO( + "cann pool[%d]: reused %p, " + "pool_size = %5u MB, " + "size = %5u MB, " + "margin = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(margin, 1048576) / 1048576)); #endif - if (b.size >= size) { - size_t diff = b.size - size; - if (diff < best_diff) { - best_diff = diff; - ibest = i; - if (!best_diff) { - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; - return ptr; - } - } + break; } } + + bool should_clean = !disable_clean && + b.size > min_free_margin && + std::chrono::duration_cast(now - b.last_used).count() > 100; + if (should_clean) { + // free the buffer if the size is needed to be freed + ACL_CHECK(aclrtFree(b.ptr)); + pool_size -= b.size; +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: clean %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); +#endif + b.ptr = nullptr; + } } - if (ibest >= 0) { - ggml_cann_buffer& b = buffer_pool[ibest]; - void* ptr = b.ptr; - *actual_size = b.size; - b.ptr = nullptr; - b.size = 0; + if (ptr != nullptr) { return ptr; } - void* ptr; - ggml_cann_set_device(device); - ACL_CHECK( - aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); - *actual_size = size; - pool_size += size; + + if (i < MAX_BUFFERS) { + // allocate a new buffer if no buffer can be reused + ggml_cann_buffer& b = buffer_pool[i]; + ggml_cann_set_device(device); + ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + pool_size += size; + *actual_size = size; + b.size = size; + b.used = true; + if (i >= MAX_BUFFERS - 8) { + GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device); + } #ifdef DEBUG_CANN_MALLOC - GGML_LOG_INFO( - "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, " - "requested %u MB\n", - __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024), - (uint32_t)(pool_size / 1024 / 1024), - (uint32_t)(size / 1024 / 1024)); + GGML_LOG_INFO( + "cann pool[%d]: allocate %p, " + "pool_size = %5u MB, " + "size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576), + (uint32_t)(GGML_PAD(b.size, 1048576) / 1048576)); #endif - return ptr; + return b.ptr; + } + + GGML_ABORT("cann pool[%d]: slots full\n", device); } /** @@ -278,18 +555,24 @@ struct ggml_cann_pool_leg : public ggml_cann_pool { * @param size Size of the buffer to free. */ void free(void* ptr, size_t size) override { + GGML_UNUSED(size); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cann_buffer& b = buffer_pool[i]; - if (b.ptr == nullptr) { - b.ptr = ptr; - b.size = size; - return; + if (b.ptr != ptr) { + continue; } + b.used = false; + b.last_used = std::chrono::steady_clock::now(); +#ifdef DEBUG_CANN_MALLOC + GGML_LOG_INFO( + "cann pool[%d]: return %p, " + "pool_size = %5u MB\n", + device, b.ptr, + (uint32_t)(GGML_PAD(pool_size, 1048576) / 1048576)); +#endif + return; } - // memory should always buffered. these memory may still needed by - // tasks in stream. - // TODO, fix me. - GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n"); + GGML_ABORT("cann pool[%d]: slots full\n", device); } }; @@ -347,8 +630,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_vmm(int device) - : device(device), - granularity(ggml_cann_info().devices[device].vmm_granularity) { + : device(device) { auto dev = ggml_cann_info().devices[device]; granularity = dev.vmm_granularity; max_size = dev.total_vram; @@ -471,7 +753,20 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device( int device) { - return std::unique_ptr(new ggml_cann_pool_vmm(device)); + std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); + + if (mem_pool_type == "prio") { + GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_buf_prio(device)); + } + + if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") { + GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_vmm(device)); + } + + GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device); + return std::unique_ptr(new ggml_cann_pool_buf(device)); } // cann buffer @@ -1020,8 +1315,11 @@ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, ggml_cann_set_device(buft_ctx->device); - size = std::max(size, (size_t)1); - + const size_t alignment = 128; + size = GGML_PAD(size, alignment); + if (size == 0) { + size = alignment; + } void* dev_ptr; aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); if (err != ACL_SUCCESS) { @@ -1333,7 +1631,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, auto lambda = [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { - GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); }; ggml_cann_unary_op(lambda, ctx, dst); } break; @@ -1399,7 +1697,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_mul_mat(ctx, dst); break; case GGML_OP_MUL_MAT_ID: - return false; + ggml_cann_mul_mat_id(ctx, dst); + break; case GGML_OP_SCALE: ggml_cann_scale(ctx, dst); break; @@ -1474,6 +1773,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_COUNT_EQUAL: ggml_cann_count_equal(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cann_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -1516,12 +1818,11 @@ static void ggml_backend_cann_free(ggml_backend_t backend) { delete backend; } + /** * @brief Sets tensor data asynchronously in the CANN backend. * - * This function asynchronously sets tensor data in the CANN backend. Depending - * on the tensor type, it may perform data transformations before copying data - * to the device. + * This function asynchronously sets tensor data in the CANN backend. * * @param backend Pointer to the CANN backend structure. * @param tensor Pointer to the tensor structure to set data for. @@ -1536,23 +1837,28 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend, size_t size) { ggml_backend_cann_context *cann_ctx = (ggml_backend_cann_context *)backend->context; + ggml_backend_buffer_t buf = + tensor->view_src ? tensor->view_src->buffer : tensor->buffer; - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data, - size, ACL_MEMCPY_HOST_TO_DEVICE, - cann_ctx->stream())); - } else { - void *transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); + GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && + "unsupported buffer type"); + GGML_ASSERT(!ggml_is_quantized(tensor->type)); - ACL_CHECK(aclrtMemcpyAsync( - (char *)tensor->data + offset, size, transform_buffer, size, - ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream())); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); - free(transform_buffer); - } + ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size, + ACL_MEMCPY_HOST_TO_DEVICE); } +/** + * @brief Gets tensor data asynchronously in the CANN backend. + * + * This function asynchronously gets tensor data in the CANN backend. + * + * @param backend Pointer to the CANN backend structure. + * @param tensor Pointer to the tensor structure to get data from. + * @param data Pointer to the host data to copy from the tensor. + * @param offset Offset in bytes within the host data. + * @param size Size of the data to copy in bytes. + */ static void ggml_backend_cann_get_tensor_async( ggml_backend_t backend, const ggml_tensor *tensor, void *data, size_t offset, size_t size) { @@ -1563,20 +1869,11 @@ static void ggml_backend_cann_get_tensor_async( GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type"); + GGML_ASSERT(!ggml_is_quantized(tensor->type)); + + ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size, + ACL_MEMCPY_DEVICE_TO_HOST); - if (!need_transform(tensor->type)) { - ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset, - size, ACL_MEMCPY_DEVICE_TO_HOST, - cann_ctx->stream())); - } else { - void *transform_buffer = malloc(size); - ACL_CHECK(aclrtMemcpyAsync( - transform_buffer, size, (char *)tensor->data + offset, size, - ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream())); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); - ggml_backend_cann_transform_back(tensor, transform_buffer, data); - free(transform_buffer); - } } /** @@ -1636,6 +1933,8 @@ static bool ggml_backend_cann_cpy_tensor_async( ggml_cann_set_device(cann_ctx_src->device); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0)); + // wait for task_queue empty to keep task order. + cann_ctx_src->task_queue.wait(); ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, cann_ctx_src->stream())); @@ -1663,9 +1962,8 @@ static bool ggml_backend_cann_cpy_tensor_async( static void ggml_backend_cann_synchronize(ggml_backend_t backend) { ggml_backend_cann_context* cann_ctx = (ggml_backend_cann_context*)backend->context; - + cann_ctx->task_queue.wait(); ggml_cann_set_device(cann_ctx->device); - ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); } @@ -1749,6 +2047,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return true; case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: +#ifdef ASCEND_310P + // Q4 && Q8 per group is not suppor on 310p device + return false; +#endif // only support contiguous for quantized types. return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); @@ -1757,7 +2059,22 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } } case GGML_OP_MUL_MAT_ID: - return false; + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: +#ifdef ASCEND_310P + // Q4 && Q8 per group is not suppor on 310p device + return false; +#endif + // only support contiguous for quantized types. + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]); + default: + return false; + } // embedding case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -1816,6 +2133,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return false; } + if(!ggml_is_contiguous(op->src[0])){ + return false; + } return true; } case GGML_OP_UPSCALE: { @@ -1831,6 +2151,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } case GGML_OP_POOL_2D: { const int32_t * opts = (const int32_t *) op->op_params; +#ifdef ASCEND_310P + enum ggml_op_pool opt = static_cast(opts[0]); + if(opt == GGML_OP_POOL_MAX){ + return false; + } +#endif const int k0 = opts[1]; const int k1 = opts[2]; const int p0 = opts[5]; @@ -1879,6 +2205,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: return true; + case GGML_OP_FLASH_ATTN_EXT:{ + // derived from [ggml-cuda.cu] + if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){ + return false; + } + if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){ + return false; + } + if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){ + return false; + } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } + if (op->src[0]->ne[0] == 192) { + return false; + } + if (op->src[0]->ne[0] == 576) { + // DeepSeek MLA + return false; + } + if (op->src[0]->ne[3] != 1) { + return false; + } + float logitSoftcap = 0.0f; + memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float)); + if(logitSoftcap != 0.0f) { + return false; + } + return true; + } default: return false; } diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 086c822d7..fbb04426a 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -1074,6 +1074,10 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() +GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16) + -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113, +GGML_TABLE_END() + #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index e73a3b69b..df0034057 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -1,3 +1,17 @@ +function(ggml_add_cpu_backend_features cpu_name arch) + # The feature detection code is compiled as a separate target so that + # it can be built without the architecture flags + # Since multiple variants of the CPU backend may be included in the same + # build, using set_source_files_properties() to set the arch flags is not possible + set(GGML_CPU_FEATS_NAME ${cpu_name}-feats) + add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/arch/${arch}/cpu-feats.cpp) + target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include) + target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN}) + target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) + set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME}) +endfunction() + function(ggml_add_cpu_backend_variant_impl tag_name) if (tag_name) set(GGML_CPU_NAME ggml-cpu-${tag_name}) @@ -10,14 +24,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list (APPEND GGML_CPU_SOURCES ggml-cpu/ggml-cpu.c ggml-cpu/ggml-cpu.cpp - ggml-cpu/ggml-cpu-aarch64.cpp - ggml-cpu/ggml-cpu-aarch64.h - ggml-cpu/ggml-cpu-hbm.cpp - ggml-cpu/ggml-cpu-hbm.h - ggml-cpu/ggml-cpu-quants.c - ggml-cpu/ggml-cpu-quants.h - ggml-cpu/ggml-cpu-traits.cpp - ggml-cpu/ggml-cpu-traits.h + ggml-cpu/repack.cpp + ggml-cpu/repack.h + ggml-cpu/hbm.cpp + ggml-cpu/hbm.h + ggml-cpu/quants.c + ggml-cpu/quants.h + ggml-cpu/traits.cpp + ggml-cpu/traits.h ggml-cpu/amx/amx.cpp ggml-cpu/amx/amx.h ggml-cpu/amx/mmq.cpp @@ -82,12 +96,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name) target_link_libraries(${GGML_CPU_NAME} PUBLIC memkind) endif() - if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR - CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR - (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND - CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$")) - + if (GGML_SYSTEM_ARCH STREQUAL "ARM") message(STATUS "ARM detected") + list(APPEND GGML_CPU_SOURCES + ggml-cpu/arch/arm/quants.c + ggml-cpu/arch/arm/repack.cpp + ) if (MSVC AND NOT CMAKE_C_COMPILER_ID STREQUAL "Clang") message(FATAL_ERROR "MSVC is not supported for ARM, use clang") @@ -143,6 +157,49 @@ function(ggml_add_cpu_backend_variant_impl tag_name) else() if (GGML_CPU_ARM_ARCH) list(APPEND ARCH_FLAGS -march=${GGML_CPU_ARM_ARCH}) + elseif(GGML_CPU_ALL_VARIANTS) + # Begin with the lowest baseline + set(ARM_MCPU "armv8-a") + set(ARCH_TAGS "") + set(ARCH_DEFINITIONS "") + + # When a feature is selected, bump the MCPU to the first + # version that supported it + if (GGML_INTERNAL_DOTPROD) + set(ARM_MCPU "armv8.2-a") + set(ARCH_TAGS "${ARCH_TAGS}+dotprod") + list(APPEND ARCH_DEFINITIONS GGML_USE_DOTPROD) + endif() + if (GGML_INTERNAL_FP16_VECTOR_ARITHMETIC) + set(ARM_MCPU "armv8.2-a") + set(ARCH_TAGS "${ARCH_TAGS}+fp16") + list(APPEND ARCH_DEFINITIONS GGML_USE_FP16_VECTOR_ARITHMETIC) + endif() + if (GGML_INTERNAL_SVE) + set(ARM_MCPU "armv8.2-a") + set(ARCH_TAGS "${ARCH_TAGS}+sve") + list(APPEND ARCH_DEFINITIONS GGML_USE_SVE) + endif() + if (GGML_INTERNAL_MATMUL_INT8) + set(ARM_MCPU "armv8.6-a") + set(ARCH_TAGS "${ARCH_TAGS}+i8mm") + list(APPEND ARCH_DEFINITIONS GGML_USE_MATMUL_INT8) + endif() + if (GGML_INTERNAL_SVE2) + set(ARM_MCPU "armv8.6-a") + set(ARCH_TAGS "${ARCH_TAGS}+sve2") + list(APPEND ARCH_DEFINITIONS GGML_USE_SVE2) + endif() + if (GGML_INTERNAL_NOSVE) + set(ARCH_TAGS "${ARCH_TAGS}+nosve") + endif() + if (GGML_INTERNAL_SME) + set(ARM_MCPU "armv9.2-a") + set(ARCH_TAGS "${ARCH_TAGS}+sme") + list(APPEND ARCH_DEFINITIONS GGML_USE_SME) + endif() + list(APPEND ARCH_FLAGS "-march=${ARM_MCPU}${ARCH_TAGS}") + ggml_add_cpu_backend_features(${GGML_CPU_NAME} arm ${ARCH_DEFINITIONS}) endif() endif() @@ -170,11 +227,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endforeach() endif() endif() - elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR - (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND - CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64|amd64)$")) - + elseif (GGML_SYSTEM_ARCH STREQUAL "x86") message(STATUS "x86 detected") + list(APPEND GGML_CPU_SOURCES + ggml-cpu/arch/x86/quants.c + ggml-cpu/arch/x86/repack.cpp + ) if (MSVC) # instruction set detection for MSVC only @@ -222,7 +280,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) elseif (GGML_AVX) list(APPEND ARCH_FLAGS /arch:AVX) list(APPEND ARCH_DEFINITIONS GGML_AVX) - else () + elseif (GGML_SSE42) list(APPEND ARCH_FLAGS /arch:SSE4.2) list(APPEND ARCH_DEFINITIONS GGML_SSE42) endif() @@ -237,8 +295,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_NATIVE) list(APPEND ARCH_FLAGS -march=native) else () - list(APPEND ARCH_FLAGS -msse4.2) - list(APPEND ARCH_DEFINITIONS GGML_SSE42) + if (GGML_SSE42) + list(APPEND ARCH_FLAGS -msse4.2) + list(APPEND ARCH_DEFINITIONS GGML_SSE42) + endif() if (GGML_F16C) list(APPEND ARCH_FLAGS -mf16c) list(APPEND ARCH_DEFINITIONS GGML_F16C) @@ -297,8 +357,17 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() endif() - elseif ("${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "ppc64le " OR "${CMAKE_SYSTEM_PROCESSOR} " STREQUAL "powerpc ") + + if (GGML_BACKEND_DL) + if (GGML_NATIVE) + # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE + message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS") + endif() + ggml_add_cpu_backend_features(${GGML_CPU_NAME} x86 ${ARCH_DEFINITIONS}) + endif() + elseif (GGML_SYSTEM_ARCH STREQUAL "PowerPC") message(STATUS "PowerPC detected") + list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/powerpc/quants.c) if (GGML_NATIVE) if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") file(READ "/proc/cpuinfo" POWER10_M) @@ -306,7 +375,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name) execute_process(COMMAND bash -c "prtconf |grep 'Implementation' | head -n 1" OUTPUT_VARIABLE POWER10_M) endif() - string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M}") + string(TOUPPER "${POWER10_M}" POWER10_M_UPPER) + string(REGEX MATCHALL "POWER *([0-9]+)" MATCHED_STRING "${POWER10_M_UPPER}") string(REGEX REPLACE "POWER *([0-9]+)" "\\1" EXTRACTED_NUMBER "${MATCHED_STRING}") if (EXTRACTED_NUMBER GREATER_EQUAL 10) @@ -323,8 +393,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -mcpu=${GGML_CPU_POWERPC_CPUTYPE}) endif() endif() - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") + elseif (GGML_SYSTEM_ARCH STREQUAL "loongarch64") message(STATUS "loongarch64 detected") + list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/loongarch/quants.c) list(APPEND ARCH_FLAGS -march=loongarch64) if (GGML_LASX) @@ -333,27 +404,38 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_LSX) list(APPEND ARCH_FLAGS -mlsx) endif() - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "riscv64") - message(STATUS "RISC-V detected") + elseif (GGML_SYSTEM_ARCH STREQUAL "riscv64") + message(STATUS "riscv64 detected") + list(APPEND GGML_CPU_SOURCES + ggml-cpu/arch/riscv/quants.c + ggml-cpu/arch/riscv/repack.cpp + ) if (GGML_RVV) - if (GGML_RV_ZFH) - list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -DGGML_RV_ZFH -mabi=lp64d) + if (GGML_XTHEADVECTOR) + list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) + elseif (GGML_RV_ZFH) + list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d) else() list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) endif() endif() - elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") message(STATUS "s390x detected") + list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c) file(READ "/proc/cpuinfo" CPUINFO_CONTENTS) string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS}) # TODO: Separation to determine activation of VX/VXE/VXE2 if (${S390X_M} MATCHES "8561|8562") message(STATUS "z15 target") - list(APPEND ARCH_FLAGS -march=z15 -mtune=z15) + list(APPEND ARCH_FLAGS -march=z15) elseif (${S390X_M} MATCHES "3931") message(STATUS "z16 target") - list(APPEND ARCH_FLAGS -march=z16 -mtune=z16) + list(APPEND ARCH_FLAGS -march=z16) + elseif (${S390X_M} MATCHES "9175|9176") + # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version. + message(STATUS "z17 target") + list(APPEND ARCH_FLAGS -march=z17) else() message(STATUS "Unknown target") message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.") @@ -363,12 +445,16 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_VXE) list(APPEND ARCH_FLAGS -mvx -mzvector) endif() + elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm") + message(STATUS "Wasm detected") + list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c) else() - message(STATUS "Unknown architecture") + message(WARNING "Unknown CPU architecture. Falling back to generic implementations.") + list(APPEND ARCH_FLAGS -DGGML_CPU_GENERIC) endif() - if (GGML_CPU_AARCH64) - target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64) + if (GGML_CPU_REPACK) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_REPACK) endif() if (GGML_CPU_KLEIDIAI) @@ -379,9 +465,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.5.0") + set(KLEIDIAI_COMMIT_TAG "v1.6.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "ea22e1aefb800e9bc8c74d91633cc58e") + set(KLEIDIAI_ARCHIVE_MD5 "75b4ad68f25ab673dcc01065e5a0b05f") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -422,6 +508,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") @@ -432,17 +519,19 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED) string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED) - set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS}) + set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP}) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) if (NOT DOTPROD_ENABLED MATCHES -1) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c) endif() if (NOT I8MM_ENABLED MATCHES -1) @@ -450,9 +539,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() if (NOT SME_ENABLED MATCHES -1) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c) - list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c) - set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2") + list(APPEND GGML_KLEIDIAI_SOURCES + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c) + set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") endif() set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}") @@ -464,25 +557,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS}) target_compile_definitions(${GGML_CPU_NAME} PRIVATE ${ARCH_DEFINITIONS}) - if (GGML_BACKEND_DL) - if (GGML_NATIVE) - # the feature check relies on ARCH_DEFINITIONS, but it is not set with GGML_NATIVE - message(FATAL_ERROR "GGML_NATIVE is not compatible with GGML_BACKEND_DL, consider using GGML_CPU_ALL_VARIANTS") - endif() - - # The feature detection code is compiled as a separate target so that - # it can be built without the architecture flags - # Since multiple variants of the CPU backend may be included in the same - # build, using set_source_files_properties() to set the arch flags is not possible - set(GGML_CPU_FEATS_NAME ${GGML_CPU_NAME}-feats) - add_library(${GGML_CPU_FEATS_NAME} OBJECT ggml-cpu/cpu-feats-x86.cpp) - target_include_directories(${GGML_CPU_FEATS_NAME} PRIVATE . .. ../include) - target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARCH_DEFINITIONS}) - target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) - set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_link_libraries(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_FEATS_NAME}) - endif() - if (EMSCRIPTEN) set_target_properties(${GGML_CPU_NAME} PROPERTIES COMPILE_FLAGS "-msimd128") endif() diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 0f067137d..258857b00 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -5,7 +5,7 @@ #include "ggml-backend.h" #include "ggml-impl.h" #include "ggml-cpu.h" -#include "ggml-cpu-traits.h" +#include "traits.h" #if defined(__gnu_linux__) #include diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 0ea91596b..cec34eb64 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -8,7 +8,7 @@ #include "mmq.h" #include "ggml-impl.h" #include "ggml-cpu-impl.h" -#include "ggml-cpu-quants.h" +#include "quants.h" #include "ggml-quants.h" #include #include diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h new file mode 100644 index 000000000..10e534251 --- /dev/null +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -0,0 +1,184 @@ +#pragma once + +// Rename `_generic` functions if no native implementation is available. +// This effectively selects the generic implementation. + +#if defined(GGML_CPU_GENERIC) +// quants.c +#define quantize_row_q8_0_generic quantize_row_q8_0 +#define quantize_row_q8_1_generic quantize_row_q8_1 +#define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_q4_0_q8_0_generic ggml_vec_dot_q4_0_q8_0 +#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 +#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0 +#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 +#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 +#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K +#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K +#define ggml_vec_dot_q3_K_q8_K_generic ggml_vec_dot_q3_K_q8_K +#define ggml_vec_dot_q4_K_q8_K_generic ggml_vec_dot_q4_K_q8_K +#define ggml_vec_dot_q5_K_q8_K_generic ggml_vec_dot_q5_K_q8_K +#define ggml_vec_dot_q6_K_q8_K_generic ggml_vec_dot_q6_K_q8_K +#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K +#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K +#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K +#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K +#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K +#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K +#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 +#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K +// repack.cpp +#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 +#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 +#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) +// repack.cpp +#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) +// repack.cpp +#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 +#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 +#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#elif defined(__POWERPC__) || defined(__powerpc__) +// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 +// quants.c +#define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K +#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +// repack.cpp +#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 +#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 +#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#elif defined(__loongarch64) +// quants.c +#define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K +#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +// repack.cpp +#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 +#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 +#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#elif defined(__riscv) +// quants.c +#define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K +#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K +#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K +#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K +#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K +#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K +#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K +#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 +#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K +// repack.cpp +#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 +#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 +#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#elif defined(__s390x__) +// quants.c +#define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0 +#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 +#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K +#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K +#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K +#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K +#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K +#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K +#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K +#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K +#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +// repack.cpp +#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 +#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 +#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#elif defined(__wasm__) +// quants.c +#define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1 +#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K +#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K +#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K +#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K +#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K +#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K +#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K +#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K +#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K +#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 +#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K +// repack.cpp +#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 +#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 +#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 +#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#endif diff --git a/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp new file mode 100644 index 000000000..67369147c --- /dev/null +++ b/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp @@ -0,0 +1,94 @@ +#include "ggml-backend-impl.h" + +#if defined(__aarch64__) + +#if defined(__linux__) +#include +#elif defined(__APPLE__) +#include +#endif + +#if !defined(HWCAP2_I8MM) +#define HWCAP2_I8MM (1 << 13) +#endif + +#if !defined(HWCAP2_SME) +#define HWCAP2_SME (1 << 23) +#endif + +struct aarch64_features { + // has_neon not needed, aarch64 has NEON guaranteed + bool has_dotprod = false; + bool has_fp16_va = false; + bool has_sve = false; + bool has_sve2 = false; + bool has_i8mm = false; + bool has_sme = false; + + aarch64_features() { +#if defined(__linux__) + uint32_t hwcap = getauxval(AT_HWCAP); + uint32_t hwcap2 = getauxval(AT_HWCAP2); + + has_dotprod = !!(hwcap & HWCAP_ASIMDDP); + has_fp16_va = !!(hwcap & HWCAP_FPHP); + has_sve = !!(hwcap & HWCAP_SVE); + has_sve2 = !!(hwcap2 & HWCAP2_SVE2); + has_i8mm = !!(hwcap2 & HWCAP2_I8MM); + has_sme = !!(hwcap2 & HWCAP2_SME); +#elif defined(__APPLE__) + int oldp = 0; + size_t size = sizeof(oldp); + + if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) == 0) { + has_dotprod = static_cast(oldp); + } + + if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) == 0) { + has_i8mm = static_cast(oldp); + } + + if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) == 0) { + has_sme = static_cast(oldp); + } + + // Apple apparently does not implement SVE yet +#endif + } +}; + +static int ggml_backend_cpu_aarch64_score() { + int score = 1; + aarch64_features af; + +#ifdef GGML_USE_DOTPROD + if (!af.has_dotprod) { return 0; } + score += 1<<1; +#endif +#ifdef GGML_USE_FP16_VECTOR_ARITHMETIC + if (!af.has_fp16_va) { return 0; } + score += 1<<2; +#endif +#ifdef GGML_USE_SVE + if (!af.has_sve) { return 0; } + score += 1<<3; +#endif +#ifdef GGML_USE_MATMUL_INT8 + if (!af.has_i8mm) { return 0; } + score += 1<<4; +#endif +#ifdef GGML_USE_SVE2 + if (!af.has_sve2) { return 0; } + score += 1<<5; +#endif +#ifdef GGML_USE_SME + if (!af.has_sme) { return 0; } + score += 1<<6; +#endif + + return score; +} + +GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cpu_aarch64_score) + +# endif // defined(__aarch64__) diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c new file mode 100644 index 000000000..b0909dac0 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -0,0 +1,4113 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +#if defined(__ARM_NEON) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); + } + + y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + +// placeholder implementation for Apple targets +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q8_K_ref(x, y, k); +} + +//===================================== Dot products ================================= + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_0 * GGML_RESTRICT vx0 = vx; + const block_q4_0 * GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx); + const block_q8_0 * GGML_RESTRICT vy0 = vy; + const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_0 * GGML_RESTRICT b_x0 = &vx0[i]; + const block_q4_0 * GGML_RESTRICT b_x1 = &vx1[i]; + const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i]; + const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); + const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); + const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); + const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32 (sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + + return; + } +#endif + + int ib = 0; + float sumf = 0; + +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + + // VLA Implementation using switch case + switch (vector_length) { + case 128: + { + // predicate for activating higher lanes for 4 float32 elements + const svbool_t ph4 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F)); + const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04)); + const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F)); + const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); + + // sub 8 + const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); + const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); + const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); + const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); + + // load y + const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16); + const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); + const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16); + + // dot product + sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx0ls, qy0l), + svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4, + svdot_s32(svdup_n_s32(0), qx1ls, qy1l), + svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 256: + { + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); + const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating higher lanes for 32 int8 elements + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes + const svbool_t pl16 = svnot_b_z(ph32, ph16); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // load x + const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); + const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); + + // 4-bit -> 8-bit + const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); + const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); + + // sub 8 + const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8); + const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8); + + // load y + const svint8_t qy0 = svld1_s8(ph32, y0->qs); + const svint8_t qy1 = svld1_s8(ph32, y1->qs); + + // dot product + sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32, + svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + +#elif defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_1 * GGML_RESTRICT vx0 = vx; + const block_q4_1 * GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx); + const block_q8_1 * GGML_RESTRICT vy0 = vy; + const block_q8_1 * GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t summs0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q4_1 * GGML_RESTRICT b_x0 = &vx0[i]; + const block_q4_1 * GGML_RESTRICT b_x1 = &vx1[i]; + const block_q8_1 * GGML_RESTRICT b_y0 = &vy0[i]; + const block_q8_1 * GGML_RESTRICT b_y1 = &vy1[i]; + + float32_t summs_t[4] = { + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), + GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), + GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s) + }; + summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); + const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); + + // 4-bit -> 8-bit + const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + // mmla into int32x4_t + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + sumv2 = vaddq_f32(sumv2, summs0); + + vst1_f32(s, vget_low_f32 (sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + + return; + } +#endif + + int ib = 0; + float sumf = 0; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + for (; ib + 1 < nb; ib += 2) { + const block_q4_1 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q4_1 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1]; + + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + // dot product into int32x4_t + const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (; ib + 1 < nb; ib += 2) { + const block_q5_0 * GGML_RESTRICT x0 = &x[ib]; + const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + // extract the 5th bit via lookup table ((!b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + for (; ib + 1 < nb; ib += 2) { + const block_q5_1 * GGML_RESTRICT x0 = &x[ib]; + const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; + const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); + summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); + + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); +#if defined(__ARM_FEATURE_MATMUL_INT8) + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q8_0 * GGML_RESTRICT vx0 = vx; + const block_q8_0 * GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx); + const block_q8_0 * GGML_RESTRICT vy0 = vy; + const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); + + float32x4_t sumv0 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const block_q8_0 * GGML_RESTRICT b_x0 = &vx0[i]; + const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i]; + + const block_q8_0 * GGML_RESTRICT b_x1 = &vx1[i]; + const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i]; + + const int8x16_t x0_l = vld1q_s8(b_x0->qs); + const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); + const int8x16_t x1_l = vld1q_s8(b_x1->qs); + const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); + + // load y + const int8x16_t y0_l = vld1q_s8(b_y0->qs); + const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); + const int8x16_t y1_l = vld1q_s8(b_y1->qs); + const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); + + float32_t _scale[4] = { + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), + GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) + }; + float32x4_t scale = vld1q_f32(_scale); + + int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); + + int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); + + int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); + + int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); + + sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), + l1, r1)), l2, r2)), l3, r3))), scale); + } + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); + float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + + vst1_f32(s, vget_low_f32 (sumv2)); + vst1_f32(s + bs, vget_high_f32(sumv2)); + + return; + } +#endif + + int ib = 0; + float sumf = 0; + +#if defined(__ARM_FEATURE_SVE) + svfloat32_t sumv0 = svdup_n_f32(0.0f); + svfloat32_t sumv1 = svdup_n_f32(0.0f); + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + + //VLA Implemenation for SVE + switch (vector_length) { + case 128: + { + // predicate for activating lanes for 16 Int8 elements + const svbool_t ph16 = svptrue_pat_b8 (SV_VL16); + const svbool_t pl16 = svptrue_pat_b32(SV_VL4); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // load x + const svint8_t qx0_0 = svld1_s8(ph16, x0->qs); + const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16); + const svint8_t qx1_0 = svld1_s8(ph16, x1->qs); + const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16); + + // load y + const svint8_t qy0_0 = svld1_s8(ph16, y0->qs); + const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16); + const svint8_t qy1_0 = svld1_s8(ph16, y1->qs); + const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16); + + sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx0_0, qy0_0), + svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16, + svdot_s32(svdup_n_s32(0), qx1_0, qy1_0), + svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1)); + } break; + case 256: + { + //printf("sve256"); + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // load x + const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); + const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); + + // load y + const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); + const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); + + sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), + svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); + } break; + case 512: + { + // predicate for activating high 256 bit + const svbool_t ph32 = svptrue_pat_b8(SV_VL32); + // predicate for activating low 256 bit + const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32); + + // predicate for activating high lanes for 8 float32 elements + const svbool_t ph8 = svptrue_pat_b32(SV_VL8); + // predicate for activating low lanes for 8 float32 elements + const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8); + + svfloat32_t sumv00 = svdup_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits + // and add them to make one 64 element vector + // load x + const svint8_t qx_32 = svld1_s8(ph32, x0->qs); + svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2); + + qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); + + // load y + const svint8_t qy_32 = svld1_s8(ph32, y0->qs); + svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2); + + qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); + + // scale creation + const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); + const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); + + // duplicate deq1 in first half of vector and deq2 in second half of vector + const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2); + + const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); + + sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp); + } + + sumf = svaddv_f32(svptrue_b32(), sumv00); + break; + } + default: + assert(false && "Unsupported vector length"); + break; + } +#elif defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (; ib + 1 < nb; ib += 2) { + const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; + const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + const uint8x16_t shift = vld1q_u8(k_shift); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + // first 32 bytes of 5 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); + uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); + uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); + uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); + int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); + int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); + const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); + const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); + sumi0 = vdotq_s32(sumi0, sqx8, qy8); + sumi1 = vdotq_s32(sumi1, sqx9, qy9); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); +#endif + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); + uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); + qx5 = vmulq_u8(qx5, shift); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#else + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; + } + } + } + + for (size_t l = 0; l < 4; ++l) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; + } + } + + sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d); + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + const uint8x16_t m3 = vdupq_n_u8(3); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + uint8x16_t qx0 = vld1q_u8(x[i].qs + j); + uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); + uint8x16_t qx2 = vshrq_n_u8(qx0, 2); + uint8x16_t qx3 = vshrq_n_u8(qx1, 2); + uint8x16_t qx4 = vshrq_n_u8(qx0, 4); + uint8x16_t qx5 = vshrq_n_u8(qx1, 4); + uint8x16_t qx6 = vshrq_n_u8(qx0, 6); + uint8x16_t qx7 = vshrq_n_u8(qx1, 6); + + int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#else + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t k = 0; k < 32; ++k) { + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); + } + } + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + sumf += (float) sumi * d; + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_FEATURE_SVE + const int vector_length = svcntb()*8; + const svuint8_t m3s = svdup_n_u8(0x3); + const svuint32_t m4s = svdup_n_u32(0xF); + const svint32_t vzero_sv = svdup_n_s32(0); + svfloat32_t acc_sum = svdup_n_f32(0); + svbool_t pred_s32 = svptrue_pat_b32(SV_VL4); + + switch (vector_length) { + case 128: + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + svfloat32_t d_broad = svdup_n_f32((float32_t)d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8_sv = y[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + + svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc); + const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4); + const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums); + svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4); + + const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2)); + + mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8); + const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12); + const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); + + q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8); + q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12); + + svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2)); + + svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1)); + + acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad); + + svint32_t sumi1 = svdup_n_s32(0); + + { + const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2); + svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s)); + svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s)); + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0)); + + const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16); + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3)); + + + const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3)); + + //------------------------------- + + q2 += 32; + const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s)); + const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0)); + + const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16); + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1)); + + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3)); + + + const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1)); + + + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2)); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3)); + } + acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad); + } + *s = svaddv_f32(svptrue_b32(), acc_sum); + break; + + case 256: + case 512: + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + svfloat32_t d_broad = svdup_n_f32((float32_t)d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8_sv = y[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + + const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8; + const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s)); + const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4)); + svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums); + + const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); + const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s)); + const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4)); + + svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8); + + svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2))); + + acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad); + + svint32_t sumi1 = svdup_n_s32(0); + + { + const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); + svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s)); + svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); + + q2 += 32; + + const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); + + q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s)); + q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7)); + sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); + } + acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad); + } + *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum); + break; + + default: + assert(false && "Unsupported vector length"); + break; + } + +#elif __ARM_NEON + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); + + const int32x4_t vzero = vdupq_n_s32(0); + + ggml_int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; + int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + for (int j = 0; j < QK_K/128; ++j) { + const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; + + ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + + sum += d * isum; + } + + *s = sum; + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_FEATURE_SVE) + + uint32_t aux[3]; + uint32_t utmp[4]; + + const int8_t m32 = 32; + const int vector_length = svcntb()*8; + const svuint8_t m3b_sv = svdup_n_u8(0x3); + const svint32_t vzero_sv = svdup_n_s32(0); + + const svuint8_t m0_sv = svdup_n_u8(1); + const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1); + const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2); + const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3); + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q3_sv = x[i].qs; + const uint8_t * GGML_RESTRICT qh_sv = x[i].hmask; + const int8_t * GGML_RESTRICT q8_sv = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + switch (vector_length) { + case 128: + { + svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv); + svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16); + svuint8_t q3h_sv; + + svint32_t sumi1_1 = svdup_n_s32(0); + svint8_t q3bytes_sv; + + for (int j = 0; j < QK_K/128; ++j) { + + const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; + const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; + svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); + + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); + + q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); + + + scale += 4; + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); + + q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); + + + q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; + + q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); + + q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1); + q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); + + if (j == 0) { + qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4); + qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4); + } + + scale += 4; + } + + sum += d * (svaddv_s32(svptrue_b32(), sumi1_1)); + } break; + case 256: + case 512: + { + svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv); + svuint8_t q3h_sv; + + svint32_t sumi1_1 = svdup_n_s32(0); + svint8_t q3bytes_sv; + + for (int j = 0; j < QK_K/128; ++j) { + + const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32; + svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + + svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); + + q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); + + scale += 4; + q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; + + q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); + + q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1); + q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); + + scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); + sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); + + if (j == 0) { + qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4); + } + + scale += 4; + } + + sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1)); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sum; + +#elif __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + ggml_int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; + const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; + const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; + + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; + + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); +#ifdef __ARM_FEATURE_MATMUL_INT8 + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q4_K * GGML_RESTRICT x0 = x; + const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx); + const block_q8_K * GGML_RESTRICT y0 = y; + const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by); + + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + float32x4_t vfsum = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) { + const uint8_t * GGML_RESTRICT qx0 = x0->qs; + const uint8_t * GGML_RESTRICT qx1 = x1->qs; + const int8_t * GGML_RESTRICT qy0 = y0->qs; + const int8_t * GGML_RESTRICT qy1 = y1->qs; + + // decode scales and mins + int8_t x0_scales[8], x1_scales[8]; + int16x8_t x0_mins, x1_mins; + { + uint32_t scales_mins[3]; + memcpy(scales_mins, x0->scales, 12); + const uint32_t mins_0_3 = scales_mins[1] & kmask1; + const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4); + const uint32x2_t mins = {mins_0_3, mins_4_7}; + x0_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins))); + uint32_t scales[2]; + scales[0] = scales_mins[0] & kmask1; // scales 0~3 + scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7 + memcpy(x0_scales, scales, 8); + } + { + uint32_t scales_mins[3]; + memcpy(scales_mins, x1->scales, 12); + const uint32_t mins_0_3 = scales_mins[1] & kmask1; + const uint32_t mins_4_7 = ((scales_mins[2] >> 4) & kmask2) | (((scales_mins[1] >> 6) & kmask3) << 4); + const uint32x2_t mins = {mins_0_3, mins_4_7}; + x1_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins))); + uint32_t scales[2]; + scales[0] = scales_mins[0] & kmask1; // scales 0~3 + scales[1] = (scales_mins[2] & kmask2) | (((scales_mins[0] >> 6) & kmask3) << 4); // scales 4~7 + memcpy(x1_scales, scales, 8); + } + + int32x4_t visum = {0}; + + // process 64 data points per iteration, totally 256 data points + for (int j = 0; j < QK_K / 64; ++j, qx0 += 32, qx1 += 32, qy0 += 64, qy1 += 64) { + const int8x16x4_t vy0 = vld1q_s8_x4(qy0); + const int8x16x4_t vy1 = vld1q_s8_x4(qy1); + + int8x16_t vx0[4], vx1[4]; + { + const uint8x16x2_t vv = vld1q_u8_x2(qx0); + vx0[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b)); + vx0[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b)); + vx0[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4)); + vx0[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4)); + } + { + const uint8x16x2_t vv = vld1q_u8_x2(qx1); + vx1[0] = vreinterpretq_s8_u8(vandq_u8(vv.val[0], m4b)); + vx1[1] = vreinterpretq_s8_u8(vandq_u8(vv.val[1], m4b)); + vx1[2] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[0], 4)); + vx1[3] = vreinterpretq_s8_u8(vshrq_n_u8(vv.val[1], 4)); + } + + // process 32 data points (share same block scale) per iteration + for (int k = 0; k < 2; ++k) { + const int blk = j * 2 + k; + const int32x4_t block_scale = { + x0_scales[blk], + x0_scales[blk], + x1_scales[blk], + x1_scales[blk], + }; + + int32x4_t vr = {0}; + for (int l = 0; l < 2; ++l) { + const int idx = k * 2 + l; + const int64x2_t vx0_s64 = vreinterpretq_s64_s8(vx0[idx]); + const int64x2_t vx1_s64 = vreinterpretq_s64_s8(vx1[idx]); + const int64x2_t vy0_s64 = vreinterpretq_s64_s8(vy0.val[idx]); + const int64x2_t vy1_s64 = vreinterpretq_s64_s8(vy1.val[idx]); + const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vx0_s64, vx1_s64)); + const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vx0_s64, vx1_s64)); + const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vy0_s64, vy1_s64)); + const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vy0_s64, vy1_s64)); + vr = vmmlaq_s32(vr, vx_l, vy_l); + vr = vmmlaq_s32(vr, vx_h, vy_h); + } + // apply block scale, will NOT overflow + // block_scale * sum_256(int4*int8) <= 2^(8+8+4+8) = 28 bits + visum = vmlaq_s32(visum, vr, block_scale); + } + } + + // adjust bias, apply superblock scale + { + int32_t bias[4]; + // no obvious uplift from sve sdot-16, just use neon mul add + const int16x8_t y0_sums = vpaddq_s16(vld1q_s16(y0->bsums), vld1q_s16(y0->bsums+8)); + const int16x8_t y1_sums = vpaddq_s16(vld1q_s16(y1->bsums), vld1q_s16(y1->bsums+8)); + bias[0] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x0_mins)), + vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x0_mins)))); + bias[1] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x0_mins)), + vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x0_mins)))); + bias[2] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y0_sums), vget_low_s16(x1_mins)), + vmull_s16(vget_high_s16(y0_sums), vget_high_s16(x1_mins)))); + bias[3] = vaddvq_s32(vaddq_s32(vmull_s16(vget_low_s16(y1_sums), vget_low_s16(x1_mins)), + vmull_s16(vget_high_s16(y1_sums), vget_high_s16(x1_mins)))); + const float32x4_t dmins = { + GGML_FP16_TO_FP32(x0->dmin) * y0->d, + GGML_FP16_TO_FP32(x0->dmin) * y1->d, + GGML_FP16_TO_FP32(x1->dmin) * y0->d, + GGML_FP16_TO_FP32(x1->dmin) * y1->d, + }; + vfsum = vmlsq_f32(vfsum, vcvtq_f32_s32(vld1q_s32(bias)), dmins); + + const float32x4_t superblock_scale = { + GGML_FP16_TO_FP32(x0->d) * y0->d, + GGML_FP16_TO_FP32(x0->d) * y1->d, + GGML_FP16_TO_FP32(x1->d) * y0->d, + GGML_FP16_TO_FP32(x1->d) * y1->d, + }; + vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale); + } + } + + // vfsum = ABCD -> ACBD + // AC -> s, BD -> (s+bs) + vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2)); + vst1_f32(s, vget_low_f32 (vfsum)); + vst1_f32(s + bs, vget_high_f32(vfsum)); + + return; + } +#endif + +#ifdef __ARM_FEATURE_SVE + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, K_SCALE_SIZE); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int vector_length = ggml_cpu_get_sve_cnt()*8; + const svuint8_t m4b = svdup_n_u8(0xf); + const svint32_t mzero = svdup_n_s32(0); + svint32_t sumi1 = svdup_n_s32(0); + svint32_t sumi1_1 = svdup_n_s32(0); + svint32_t sumi1_2 = svdup_n_s32(0); + svint32_t sumi2 = svdup_n_s32(0); + svint32_t sumi2_1 = svdup_n_s32(0); + svint32_t sumi2_2 = svdup_n_s32(0); + switch (vector_length) { + case 128: + { + for (int j = 0; j < QK_K/64; ++j) { + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); + svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); + q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; + sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + q4 += 32; + } + sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); + sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); + sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); + } break; + case 256: + case 512: + { + for (int j = 0; j < QK_K/64; ++j) { + const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; + svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); + svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); + + q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); + q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; + sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); + } + sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); + } break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + *s = sumf; +#elif defined __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x2_t q4bytes; + ggml_int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + +#ifdef __ARM_NEON + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); + + ggml_uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; + const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; + } + + sumf += d * sumi - dmin * sumi_mins; + } + + *s = sumf; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); +#ifdef __ARM_FEATURE_MATMUL_INT8 + assert((nrc == 2) || (nrc == 1)); +#else + assert(nrc == 1); +#endif + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_FEATURE_MATMUL_INT8) + if (nrc == 2) { + const block_q6_K * GGML_RESTRICT x0 = x; + const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx); + const block_q8_K * GGML_RESTRICT y0 = y; + const block_q8_K * GGML_RESTRICT y1 = (const block_q8_K *) ((const uint8_t *)vy + by); + + float32x4_t vfsum = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; ++i, ++x0, ++x1, ++y0, ++y1) { + const uint8_t * GGML_RESTRICT ql0 = x0->ql; + const uint8_t * GGML_RESTRICT ql1 = x1->ql; + const uint8_t * GGML_RESTRICT qh0 = x0->qh; + const uint8_t * GGML_RESTRICT qh1 = x1->qh; + const int8_t * GGML_RESTRICT qy0 = y0->qs; + const int8_t * GGML_RESTRICT qy1 = y1->qs; + + const uint8x16_t mone = vdupq_n_u8(0x30); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + int32x4_t visum = vdupq_n_s32(0); + + // process 8 blocks per iteration, totally 16 blocks + for (int j = 0; j < 2; ++j, qh0 += 32, ql0 += 64, qh1 += 32, ql1 += 64) { + int8x16_t vx0[8], vx1[8]; + + // de-quantize vx0[8] + { + const uint8x16x2_t qh_bits = vld1q_u8_x2(qh0); + const uint8x16x4_t ql_bits = vld1q_u8_x4(ql0); + + uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); + uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); + uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); + uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); + + vx0[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); + vx0[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1)); + vx0[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); + vx0[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); + + q6h_0 = vandq_u8(mone, qh_bits.val[0]); + q6h_1 = vandq_u8(mone, qh_bits.val[1]); + q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2)); + q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); + + vx0[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); + vx0[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); + vx0[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); + vx0[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); + } + + // de-quantize vx1[8] + { + const uint8x16x2_t qh_bits = vld1q_u8_x2(qh1); + const uint8x16x4_t ql_bits = vld1q_u8_x4(ql1); + + uint8x16_t q6h_0 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 4)); + uint8x16_t q6h_1 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 4)); + uint8x16_t q6h_2 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[0], 2)); + uint8x16_t q6h_3 = vandq_u8(mone, vshlq_n_u8(qh_bits.val[1], 2)); + + vx1[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[0], m4b), q6h_0)); + vx1[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[1], m4b), q6h_1)); + vx1[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[2], m4b), q6h_2)); + vx1[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(ql_bits.val[3], m4b), q6h_3)); + + q6h_0 = vandq_u8(mone, qh_bits.val[0]); + q6h_1 = vandq_u8(mone, qh_bits.val[1]); + q6h_2 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[0], 2)); + q6h_3 = vandq_u8(mone, vshrq_n_u8(qh_bits.val[1], 2)); + + vx1[4] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[0], 4), q6h_0)); + vx1[5] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[1], 4), q6h_1)); + vx1[6] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[2], 4), q6h_2)); + vx1[7] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(ql_bits.val[3], 4), q6h_3)); + } + + // process 16 elements (one block with same scale) per iteration + // - vx = concat(ql, qh) - 32 + // - r1,r2,r3,r4 = smmla(vx, vy) + for (int k = 0; k < 8; ++k) { + const int blk = j * 8 + k; + + const int8x16_t vy0 = vld1q_s8(qy0); + const int8x16_t vy1 = vld1q_s8(qy1); + qy0 += 16; + qy1 += 16; + + const int32x4_t block_scale = { + x0->scales[blk], + x0->scales[blk], + x1->scales[blk], + x1->scales[blk], + }; + + // calculate four results at once with outer product + const int8x16_t vx_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); + const int8x16_t vx_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vx0[k]), vreinterpretq_s64_s8(vx1[k]))); + const int8x16_t vy_l = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); + const int8x16_t vy_h = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(vy0), vreinterpretq_s64_s8(vy1))); + int32x4_t vr = vdupq_n_s32(0); + vr = vmmlaq_s32(vr, vx_l, vy_l); + vr = vmmlaq_s32(vr, vx_h, vy_h); + + // apply block scale, will NOT overflow + // block_scale * sum_256(int6*int8) <= 2^(8+8+6+8) = 30 bits + visum = vmlaq_s32(visum, vr, block_scale); + } + } + + // adjust bias, apply superblock scale + { + int32_t bias[4]; +#ifdef __ARM_FEATURE_SVE + const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8); + const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8); + const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums); + const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8); + const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums); + const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8); + const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales)); + const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8)); + const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales)); + const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8)); + const svint64_t zero = svdup_n_s64(0); + bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0), + svdot_s64(zero, y0_q8sums_1, x0_q6scales_1))); + bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0), + svdot_s64(zero, y1_q8sums_1, x0_q6scales_1))); + bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0), + svdot_s64(zero, y0_q8sums_1, x1_q6scales_1))); + bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0), + svdot_s64(zero, y1_q8sums_1, x1_q6scales_1))); +#else + // NEON doesn't support int16 dot product, fallback to separated mul and add + const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums); + const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums); + + int8x16_t scales_s8 = vld1q_s8(x0->scales); + const int16x8x2_t q6scales0 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; + scales_s8 = vld1q_s8(x1->scales); + const int16x8x2_t q6scales1 = {{vmovl_s8(vget_low_s8(scales_s8)), vmovl_s8(vget_high_s8(scales_s8))}}; + + int32x4_t prod; + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales0.val[0])), + vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales0.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales0.val[1])), + vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales0.val[1])))); + bias[0] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales0.val[0])), + vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales0.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales0.val[1])), + vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales0.val[1])))); + bias[1] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[0]), vget_low_s16 (q6scales1.val[0])), + vmull_s16(vget_high_s16(q8sums0.val[0]), vget_high_s16(q6scales1.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums0.val[1]), vget_low_s16 (q6scales1.val[1])), + vmull_s16(vget_high_s16(q8sums0.val[1]), vget_high_s16(q6scales1.val[1])))); + bias[2] = vaddvq_s32(prod); + prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[0]), vget_low_s16 (q6scales1.val[0])), + vmull_s16(vget_high_s16(q8sums1.val[0]), vget_high_s16(q6scales1.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums1.val[1]), vget_low_s16 (q6scales1.val[1])), + vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1])))); + bias[3] = vaddvq_s32(prod); + +#endif + const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32); + + const float32x4_t superblock_scale = { + GGML_FP16_TO_FP32(x0->d) * y0->d, + GGML_FP16_TO_FP32(x0->d) * y1->d, + GGML_FP16_TO_FP32(x1->d) * y0->d, + GGML_FP16_TO_FP32(x1->d) * y1->d, + }; + + visum = vsubq_s32(visum, vibias); + vfsum = vmlaq_f32(vfsum, vcvtq_f32_s32(visum), superblock_scale); + } + } + + // vfsum = ABCD -> ACBD + // AC -> s, BD -> (s+bs) + vfsum = vzip1q_f32(vfsum, vextq_f32(vfsum, vfsum, 2)); + vst1_f32(s, vget_low_f32 (vfsum)); + vst1_f32(s + bs, vget_high_f32(vfsum)); + + return; + } +#endif + +#ifdef __ARM_FEATURE_SVE + const int vector_length = ggml_cpu_get_sve_cnt()*8; + float sum = 0; + svuint8_t m4b = svdup_n_u8(0xf); + svint32_t vzero = svdup_n_s32(0); + svuint8_t mone = svdup_n_u8(0x30); + svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4; + svuint8_t q6h_1, q6h_2, q6h_3, q6h_4; + + for (int i = 0; i < nb; ++i) { + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8); + const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums); + const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8); + const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale)); + const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8)); + const svint64_t prod = svdup_n_s64(0); + int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1), + svdot_s64(prod, q8sums_2, q6scales_2))); + int32_t isum = 0; + + switch (vector_length) { + case 128: + { + const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4); + const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16); + svint32_t isum_tmp = svdup_n_s32(0); + for (int j = 0; j < QK_K/128; ++j) { + svuint8_t qhbits_1 = svld1_u8(pg8_16, qh); + svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16); + qh += 32; + svuint8_t q6bits_1 = svld1_u8(pg8_16, q6); + svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16); + svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32); + svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48); + q6 += 64; + svint8_t q8bytes_1 = svld1_s8(pg8_16, q8); + svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16); + svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32); + svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48); + q8 += 64; + + q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4)); + q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4)); + q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2)); + q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2)); + q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1)); + q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2)); + q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3)); + q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4)); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]); + + scale += 4; + q8bytes_1 = svld1_s8(pg8_16, q8); + q8bytes_2 = svld1_s8(pg8_16, q8+16); + q8bytes_3 = svld1_s8(pg8_16, q8+32); + q8bytes_4 = svld1_s8(pg8_16, q8+48); + q8 += 64; + + q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1); + q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2); + q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2)); + q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2)); + q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1)); + q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2)); + q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3)); + q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4)); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]); + isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]); + scale += 4; + } + isum += svaddv_s32(pg32_4, isum_tmp); + sum += d_all * y[i].d * (isum - 32 * isum_mins); + } + break; + case 256: + case 512: + { + const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2); + const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8); + const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32); + svint32_t isum_tmp = svdup_n_s32(0); + for (int j = 0; j < QK_K/128; j++) { + svuint8_t qhbits_1 = svld1_u8(pg8_32, qh); + qh += 32; + svuint8_t q6bits_1 = svld1_u8(pg8_32, q6); + svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32); + q6 += 64; + svint8_t q8bytes_1 = svld1_s8(pg8_32, q8); + svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32); + svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64); + svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96); + q8 += 128; + q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4)); + q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2)); + q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1); + q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2)); + q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1)); + q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2)); + q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3)); + q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4)); + + svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale); + scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp); + scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp); + svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2); + scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp); + scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp); + svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4); + scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp); + scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp); + svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6); + scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp); + scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp); + svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp)); + svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp)); + svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp)); + svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp)); + + isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1); + isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2); + isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3); + isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4); + scale += 8; + } + isum += svaddv_s32(pg32_8, isum_tmp); + sum += d_all * y[i].d * (isum - 32 * isum_mins); + } + break; + default: + assert(false && "Unsupported vector length"); + break; + } + } + + *s = sum; + +#elif __ARM_NEON + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int32x4_t vzero = vdupq_n_s32(0); + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + ggml_int8x16x4_t q6bytes; + ggml_uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; + ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; + ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + + scale += 4; + + q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + + isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + } + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#if defined (__ARM_NEON) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + ggml_int8x16x4_t q2u; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + int32x4x4_t scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); + q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); + q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); + q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); + q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); + q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); + q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); + q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); + q2 += 8; + } + sumf += d*vaddvq_s32(sumi); + } + *s = 0.125f * sumf; + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const uint8x16_t m1 = vdupq_n_u8(1); + const int32x4_t vzero = vdupq_n_s32(0); + + uint8x16x2_t vs; + ggml_int8x16x4_t q2s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); + q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); + q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); + q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), + vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); + qs += 8; + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); + q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vceqq_u8(vs.val[0], mask2); + vs.val[1] = vceqq_u8(vs.val[1], mask2); + + signs += 4; + + q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); + q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); + + const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); + const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); + const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); + const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); + sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); + sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); + sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); + } + sumf += d*(sumi1 + sumi2); + } + + *s = 0.125f * sumf; + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; + +#endif + +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); + q3 += 16; + q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); + q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); + q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); + q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.5f * sumf; + +#else + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + + typedef union { + uint16x8_t vec_index; + uint16_t index[8]; + } vec_index_t; + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + + const int16x8_t hshift = vld1q_s16(k_shift); + const uint16x8_t m256 = vdupq_n_u16(256); + const uint8x16_t m1 = vdupq_n_u8(1); + + uint8x16x2_t vs; + ggml_int8x16x4_t q3s; + ggml_int8x16x4_t q8b; + vec_index_t idx; + + uint32_t scales32[2]; + const uint8_t * scales8 = (const uint8_t *)scales32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; + + int sumi1 = 0, sumi2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], + iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], + iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); + + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); + q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); + + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); + vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); + vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); + vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); + vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); + + signs += 4; + + q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); + q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); + + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; + sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; + } + sumf += d*(sumi1 + sumi2); + } + *s = sumf; + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi1 = 0, sumi2 = 0, sumi3 = 0; + + for (int ib = 0; ib < QK_K/32; ib += 2) { + + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); + qs += 8; + + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); + + const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + sumi1 += vaddvq_s32(p1) * ls1; + sumi2 += vaddvq_s32(p2) * ls2; + sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); + + } + + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); + } + + *s = sumf; + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; + } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + +#if defined __ARM_NEON + const int32x4_t mask = vdupq_n_s32(0x7); + const int32x4_t mone = vdupq_n_s32(1); + const int32x4_t mzero = vdupq_n_s32(0); + + ggml_int8x16x4_t deltas; + deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); + deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); + deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); + deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); + + ggml_int8x16x4_t q1b; + ggml_int8x16x4_t q8b; + + uint32_t aux32; + const uint8_t * aux8 = (const uint8_t *)&aux32; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int32x4_t sumi1 = mzero; + int32x4_t sumi2 = mzero; + + for (int ib = 0; ib < QK_K/32; ib += 2) { + + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); + + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); + const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); + const int32x4_t p12 = vpaddq_s32(p1, p2); + + const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that + aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); + + const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); + const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); + const int32x4_t p34 = vpaddq_s32(p3, p4); + + int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); + + scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); + + sumi1 = vmlaq_s32(sumi1, scales_4, p12); + sumi2 = vmlaq_s32(sumi2, scales_4, p34); + + qs += 8; qh += 4; + + } + + sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); + } + + *s = sumf; + +#else + + int sum1[2], sum2[2], delta[4]; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + delta[0] = qh[0] & 0x08 ? -1 : 1; + delta[1] = qh[0] & 0x80 ? -1 : 1; + delta[2] = qh[1] & 0x08 ? -1 : 1; + delta[3] = qh[1] & 0x80 ? -1 : 1; + sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); + int lsum1 = 0, lsum2 = 0; + for (int j = 0; j < 8; ++j) { + lsum1 += q8[j] * grid[j]; + lsum2 += q8[j]; + } + q8 += 8; + sum1[l/2] += lsum1; + sum2[l/2] += lsum2*delta[l]; + } + + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; + const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; + + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; + sumi2 += sum2[0] * ls1 + sum2[1] * ls2; + qs += 4; + qh += 2; + } + + sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + uint8x16x2_t q4bits; + int8x16x4_t q4b; + int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + for (; ib + 1 < nb; ib += 2) { + + q4bits.val[0] = vld1q_u8(x[ib + 0].qs); + q4bits.val[1] = vld1q_u8(x[ib + 1].qs); + q8b.val[0] = vld1q_s8(y[ib + 0].qs); + q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); + q8b.val[2] = vld1q_s8(y[ib + 1].qs); + q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); + + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + sumf += + GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + + GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); + } + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_iq4nl); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + ggml_uint8x16x2_t q4bits; + ggml_int8x16x4_t q4b; + ggml_int8x16x4_t q8b; + int32x4_t prod_1, prod_2; + + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + + const int8_t * q8 = y[ibl].qs; + const uint8_t * q4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { + + q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); + q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); + q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); + q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); + + prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); + prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); + + int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + sumi1 += vaddvq_s32(prod_1) * ls1; + sumi2 += vaddvq_s32(prod_2) * ls2; + + } + + sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp new file mode 100644 index 000000000..9337e01b6 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -0,0 +1,2174 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" +#include "ggml-backend-impl.h" + +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" +#include "traits.h" + +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GGML_CPU_CLANG_WORKAROUND +#include "../../repack.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#endif + +#define UNUSED GGML_UNUSED + +void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 8; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); + } + } +#else + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + +#if defined(__ARM_NEON) + float32x4_t srcv[4][8]; + float id[4]; + + for (int i = 0; i < nb; i++) { + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int row_iter = 0; row_iter < 4; row_iter++) { + for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); + + for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); + for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); + for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < 4; j++) { + float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); + int32x4_t vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[1][2 * j], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[2][2 * j], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); + + v = vmulq_n_f32(srcv[3][2 * j], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); + v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); + vi = vcvtnq_s32_f32(v); + y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); + y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); + y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); + y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); + } + } + +#else + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +#endif +} + +void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); + int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); + int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); + int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16_t a0 = vld1q_s8(a_ptr->qs); + int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret = vdupq_n_s32(0); + + ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0); + ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1); + ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2); + ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3); + + ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0); + ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1); + ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2); + ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3); + + acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), + vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); + int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); + int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); + int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs); + int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1); + int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2); + int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret0 = vdupq_n_s32(0); + int32x4_t ret1 = vdupq_n_s32(0); + + ret0 = vdotq_s32(ret0, b0 << 4, a0); + ret1 = vdotq_s32(ret1, b1 << 4, a0); + ret0 = vdotq_s32(ret0, b2 << 4, a1); + ret1 = vdotq_s32(ret1, b3 << 4, a1); + + ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2); + ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2); + ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3); + ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3); + + int32x4_t ret = vpaddq_s32(ret0, ret1); + + acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), + vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) + if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + + __asm__ __volatile__( + "ptrue p0.b\n" + "add %x[b_ptr], %x[b_ptr], #0x10\n" + "1:" // Column loop + "add x22, %x[a_ptr], #0x2\n" + "mov z31.b, #0x0\n" + "mov x21, %x[nb]\n" + "2:" // Block loop + "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" + "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" + "mov z28.s, #0x0\n" + "mov z27.s, #0x0\n" + "ld1rd { z26.d }, p0/Z, [x22]\n" + "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" + "sub x20, x22, #0x2\n" + "sub x21, x21, #0x1\n" + "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" + "ld1rd { z23.d }, p0/Z, [x22, #8]\n" + "lsl z22.b, z30.b, #0x4\n" + "lsl z16.b, z29.b, #0x4\n" + "and z30.b, z30.b, #0xf0\n" + "and z29.b, z29.b, #0xf0\n" + "ld1rd { z21.d }, p0/Z, [x22, #16]\n" + "ld1rd { z20.d }, p0/Z, [x22, #24]\n" + "lsl z19.b, z25.b, #0x4\n" + "and z25.b, z25.b, #0xf0\n" + "ld1rh { z17.h }, p0/Z, [x20]\n" + "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" + "sdot z28.s, z22.b, z26.b\n" + "sdot z27.s, z16.b, z26.b\n" + "lsl z16.b, z24.b, #0x4\n" + "add x22, x22, #0x22\n" + "and z24.b, z24.b, #0xf0\n" + "add %x[b_ptr], %x[b_ptr], #0x90\n" + "fcvt z17.s, p0/m, z17.h\n" + "fcvt z18.s, p0/m, z18.h\n" + "sdot z28.s, z19.b, z23.b\n" + "sdot z27.s, z16.b, z23.b\n" + "fmul z18.s, z18.s, z17.s\n" + "sdot z28.s, z30.b, z21.b\n" + "sdot z27.s, z29.b, z21.b\n" + "sdot z28.s, z25.b, z20.b\n" + "sdot z27.s, z24.b, z20.b\n" + "uzp1 z17.s, z28.s, z27.s\n" + "uzp2 z16.s, z28.s, z27.s\n" + "add z17.s, z17.s, z16.s\n" + "asr z17.s, z17.s, #0x4\n" + "scvtf z17.s, p0/m, z17.s\n" + "fmla z31.s, p0/M, z17.s, z18.s\n" + "cbnz x21, 2b\n" + "sub %x[nc], %x[nc], #0x8\n" + "st1w { z31.s }, p0, [%x[res_ptr]]\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "cbnz %x[nc], 1b\n" + : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) + : [a_ptr] "r" (a_ptr), [nb] "r" (nb) + : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } +#endif // #if defined(__ARM_FEATURE_SVE) + +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + { + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float * res_ptr = s; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + float32x4_t sumf = vdupq_n_f32(0); + for (int l = 0; l < nb; l++) { + uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0); + uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16); + uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32); + uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48); + + int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4); + int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F); + int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4); + int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F); + int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4); + int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F); + int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4); + int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F); + + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16); + + int32x4_t sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0); + sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0); + sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1); + sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1); + sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2); + sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2); + sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3); + sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3); + + float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + float32x4_t d = a_d * b_d; + + sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi)); + } + + vst1q_f32(res_ptr + x * 4, sumf); + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" + ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x40]\n" + ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x50]\n" + ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x60]\n" + ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x21, #0x0]\n" + ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" + ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" + ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" + "fmul v9.4s, v17.4s, v29.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v25.4s, v20.4s, v9.4s\n" + "ldr q9, [x21, #0x10]\n" + "fmul v20.4s, v17.4s, v29.s[1]\n" + "fmla v7.4s, v10.4s, v20.4s\n" + "ldr d20, [x21, #-0x8]\n" + "fmul v10.4s, v17.4s, v29.s[2]\n" + "fmul v29.4s, v17.4s, v29.s[3]\n" + "fcvtl v20.4s, v20.4h\n" + "fmla v0.4s, v26.4s, v10.4s\n" + "movi v26.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v4.4s, v2.4s, v29.4s\n" + "movi v2.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" + "ldr q12, [x21, #0x20]\n" + "fmul v24.4s, v17.4s, v20.s[0]\n" + ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x30]\n" + "fmul v31.4s, v17.4s, v20.s[1]\n" + ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" + ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" + ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" + ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x40]\n" + "fmul v6.4s, v17.4s, v20.s[2]\n" + "fmul v20.4s, v17.4s, v20.s[3]\n" + ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" + ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" + ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" + ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" + "ldr q9, [x21, #0x50]\n" + ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" + ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" + ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" + ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" + "ldr q12, [x21, #0x60]\n" + ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" + ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" + ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" + ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" + "ldr q17, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" + ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" + ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" + ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" + ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" + ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" + ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" + ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "fmla v5.4s, v26.4s, v24.4s\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "fmla v21.4s, v10.4s, v31.4s\n" + "fmla v8.4s, v2.4s, v6.4s\n" + "fmla v1.4s, v29.4s, v20.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q16, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q0, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q21, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q1, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q7, [x24, #0x0]\n" + "ldr q5, [x25, #0x0]\n" + "movi v9.16b, #0x4\n" + "movi v4.4s, #0x0\n" + "ldr q3, [x24, #0x10]\n" + "ldr q2, [x25, #0x10]\n" + "movi v1.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q13, [x24, #0x20]\n" + "ldr q31, [x25, #0x20]\n" + "movi v30.4s, #0x0\n" + "movi v29.16b, #0xf0\n" + "ldr q28, [x24, #0x30]\n" + "ldr q27, [x25, #0x30]\n" + "sshl v20.16b, v7.16b, v9.16b\n" + "sub x20, x24, #0x8\n" + "ldr q26, [x25, #0x40]\n" + "ldr q25, [x25, #0x50]\n" + "sshl v17.16b, v3.16b, v9.16b\n" + "and v7.16b, v7.16b, v29.16b\n" + "ldr q24, [x25, #0x60]\n" + "ldr q16, [x25, #0x70]\n" + "sshl v22.16b, v13.16b, v9.16b\n" + "and v3.16b, v3.16b, v29.16b\n" + "ldr d21, [x20, #0x0]\n" + "ldr d12, [x25, #-0x8]\n" + ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" + ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" + ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" + ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" + "sshl v9.16b, v28.16b, v9.16b\n" + "subs x21, x21, #0x1\n" + "and v13.16b, v13.16b, v29.16b\n" + "and v28.16b, v28.16b, v29.16b\n" + "add x25, x25, #0x88\n" + "add x24, x24, #0x48\n" + "fcvtl v21.4s, v21.4h\n" + "fcvtl v12.4s, v12.4h\n" + ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" + ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" + ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" + ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" + "fmul v11.4s, v21.4s, v12.s[0]\n" + "fmul v23.4s, v21.4s, v12.s[1]\n" + "fmul v17.4s, v21.4s, v12.s[2]\n" + ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" + "fmul v6.4s, v21.4s, v12.s[3]\n" + ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" + ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" + ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" + ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" + ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" + ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" + ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" + ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" + ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" + ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" + ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" + ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" + ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" + ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" + ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" + ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" + ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" + ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" + ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" + ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" + ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" + ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" + ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" + "scvtf v4.4s, v4.4s, #0x4\n" + "scvtf v1.4s, v1.4s, #0x4\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "fmla v15.4s, v4.4s, v11.4s\n" + "scvtf v30.4s, v30.4s, #0x4\n" + "fmla v19.4s, v1.4s, v23.4s\n" + "fmla v18.4s, v0.4s, v17.4s\n" + "fmla v14.4s, v30.4s, v6.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q15, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q19, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q18, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q14, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v22.16b, #0x0\n" + "movi v23.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v6.16b, #0x0\n" + "movi v30.16b, #0x0\n" + "movi v24.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "3:" // Block loop + "ldr q21, [x28, #0x0]\n" + "ldr q16, [x28, #0x10]\n" + "movi v1.16b, #0x4\n" + "movi v19.4s, #0x0\n" + "ldr q27, [x25, #0x0]\n" + "ldr q15, [x25, #0x10]\n" + "movi v26.4s, #0x0\n" + "movi v18.4s, #0x0\n" + "ldr q29, [x28, #0x20]\n" + "ldr q3, [x28, #0x30]\n" + "movi v17.4s, #0x0\n" + "movi v0.16b, #0xf0\n" + "ldr d20, [x25, #-0x8]\n" + "ldr d9, [x23, #-0x8]\n" + "sshl v8.16b, v21.16b, v1.16b\n" + "sshl v31.16b, v16.16b, v1.16b\n" + "and v21.16b, v21.16b, v0.16b\n" + "and v16.16b, v16.16b, v0.16b\n" + "sub x20, x28, #0x8\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" + ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" + "ldr q27, [x25, #0x20]\n" + ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" + ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" + "sshl v15.16b, v29.16b, v1.16b\n" + "sshl v1.16b, v3.16b, v1.16b\n" + "and v29.16b, v29.16b, v0.16b\n" + "and v3.16b, v3.16b, v0.16b\n" + "ldr q0, [x25, #0x30]\n" + "fcvtl v20.4s, v20.4h\n" + ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" + "fcvtl v9.4s, v9.4h\n" + ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" + "ldr q27, [x25, #0x40]\n" + ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" + ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" + "ldr q0, [x25, #0x50]\n" + ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" + ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" + "ldr q27, [x25, #0x60]\n" + ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" + ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" + "ldr q0, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" + ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" + "ldr d27, [x20, #0x0]\n" + ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" + ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" + "fcvtl v27.4s, v27.4h\n" + "uzp1 v0.2d, v19.2d, v26.2d\n" + "uzp2 v26.2d, v19.2d, v26.2d\n" + "fmul v19.4s, v27.4s, v20.s[0]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v26.4s, v26.4s, #0x4\n" + "fmla v2.4s, v0.4s, v19.4s\n" + "ldr q19, [x23, #0x0]\n" + "uzp1 v0.2d, v18.2d, v17.2d\n" + "uzp2 v18.2d, v18.2d, v17.2d\n" + "fmul v17.4s, v27.4s, v20.s[1]\n" + "scvtf v0.4s, v0.4s, #0x4\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v10.4s, v26.4s, v17.4s\n" + "ldr q17, [x23, #0x10]\n" + "fmul v26.4s, v27.4s, v20.s[2]\n" + "fmul v20.4s, v27.4s, v20.s[3]\n" + "fmla v12.4s, v0.4s, v26.4s\n" + "ldr d0, [x22, #-0x8]\n" + "ldr d26, [x21, #-0x8]\n" + "fcvtl v0.4s, v0.4h\n" + "fmla v28.4s, v18.4s, v20.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x23, #0x20]\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x23, #0x40]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q19, [x23, #0x60]\n" + ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" + ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" + "uzp1 v19.2d, v20.2d, v18.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp2 v20.2d, v20.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v9.s[0]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v11.4s, v19.4s, v18.4s\n" + "ldr q18, [x22, #0x0]\n" + "fmul v19.4s, v27.4s, v9.s[1]\n" + "fmla v13.4s, v20.4s, v19.4s\n" + "movi v19.4s, #0x0\n" + "movi v20.4s, #0x0\n" + ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" + ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" + "ldr q17, [x23, #0x30]\n" + ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" + "ldr q17, [x23, #0x50]\n" + ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" + "ldr q17, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v9.s[2]\n" + "fmul v9.4s, v27.4s, v9.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v22.4s, v17.4s, v19.4s\n" + "ldr q17, [x22, #0x10]\n" + "movi v19.4s, #0x0\n" + ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" + "fmla v23.4s, v20.4s, v9.4s\n" + "movi v20.4s, #0x0\n" + "movi v9.4s, #0x0\n" + ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" + "ldr q18, [x22, #0x20]\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" + ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" + "ldr q18, [x22, #0x40]\n" + ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" + ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" + "ldr q18, [x22, #0x60]\n" + ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" + ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" + "ldr q17, [x22, #0x30]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" + "ldr q17, [x22, #0x50]\n" + ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" + ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" + "ldr q17, [x22, #0x70]\n" + "add x22, x22, #0x88\n" + ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" + ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" + "uzp1 v17.2d, v19.2d, v20.2d\n" + "uzp2 v20.2d, v19.2d, v20.2d\n" + "fmul v19.4s, v27.4s, v0.s[0]\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v25.4s, v17.4s, v19.4s\n" + "ldr q19, [x21, #0x0]\n" + "fmul v17.4s, v27.4s, v0.s[1]\n" + "fmla v5.4s, v20.4s, v17.4s\n" + "ldr q17, [x21, #0x10]\n" + "uzp1 v20.2d, v9.2d, v18.2d\n" + "uzp2 v9.2d, v9.2d, v18.2d\n" + "fmul v18.4s, v27.4s, v0.s[2]\n" + "fmul v0.4s, v27.4s, v0.s[3]\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "fmla v7.4s, v20.4s, v18.4s\n" + "movi v20.4s, #0x0\n" + "movi v18.4s, #0x0\n" + ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" + ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" + "ldr q19, [x21, #0x20]\n" + "fmla v4.4s, v9.4s, v0.4s\n" + "movi v9.4s, #0x0\n" + "movi v0.4s, #0x0\n" + ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" + "fmul v8.4s, v27.4s, v26.s[0]\n" + ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" + "ldr q17, [x21, #0x30]\n" + ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" + "fmul v31.4s, v27.4s, v26.s[1]\n" + ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" + "ldr q19, [x21, #0x40]\n" + ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" + "fmul v15.4s, v27.4s, v26.s[2]\n" + "fmul v27.4s, v27.4s, v26.s[3]\n" + ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" + "ldr q1, [x21, #0x50]\n" + ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" + ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" + "ldr q26, [x21, #0x60]\n" + ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" + ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" + "ldr q21, [x21, #0x70]\n" + "add x21, x21, #0x88\n" + ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" + ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" + ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" + ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" + "uzp1 v29.2d, v20.2d, v18.2d\n" + "uzp2 v21.2d, v20.2d, v18.2d\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "uzp1 v18.2d, v9.2d, v0.2d\n" + "uzp2 v16.2d, v9.2d, v0.2d\n" + "scvtf v21.4s, v21.4s, #0x4\n" + "fmla v6.4s, v29.4s, v8.4s\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v30.4s, v21.4s, v31.4s\n" + "fmla v24.4s, v18.4s, v15.4s\n" + "fmla v14.4s, v16.4s, v27.4s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x27, x27, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q28, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q11, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q13, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q22, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q23, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q25, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q30, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q24, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "str q14, [x20, #0x0]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x10, x10, #0x10\n" + "cmp x10, #0x10\n" + "mov %x[res_ptr], x26\n" + "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x10, 9f\n" + "5:" // Row tail: Row loop + "add x24, %x[b_ptr], #0x8\n" + "mov x23, %x[nc]\n" + "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "movi v2.16b, #0x0\n" + "movi v10.16b, #0x0\n" + "add x25, %x[a_ptr], #0x8\n" + "mov x21, %x[nb]\n" + "movi v12.16b, #0x0\n" + "movi v28.16b, #0x0\n" + "7:" // Row tail: Block loop + "ldr q6, [x24, #0x0]\n" + "ldr q5, [x24, #0x10]\n" + "movi v17.16b, #0x4\n" + "movi v8.4s, #0x0\n" + "ldr q4, [x25, #0x0]\n" + "ldr q13, [x25, #0x10]\n" + "movi v27.4s, #0x0\n" + "movi v0.4s, #0x0\n" + "ldr q31, [x24, #0x20]\n" + "ldr q14, [x24, #0x30]\n" + "movi v29.4s, #0x0\n" + "movi v22.16b, #0xf0\n" + "ldr q11, [x25, #0x20]\n" + "ldr q23, [x25, #0x30]\n" + "sshl v21.16b, v6.16b, v17.16b\n" + "sshl v16.16b, v5.16b, v17.16b\n" + "ldr q20, [x25, #0x40]\n" + "ldr q26, [x25, #0x50]\n" + "and v6.16b, v6.16b, v22.16b\n" + "and v5.16b, v5.16b, v22.16b\n" + "ldr q25, [x25, #0x60]\n" + "ldr q3, [x25, #0x70]\n" + "sshl v19.16b, v31.16b, v17.16b\n" + "sshl v18.16b, v14.16b, v17.16b\n" + "ldr d17, [x25, #-0x8]\n" + ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" + ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" + "and v31.16b, v31.16b, v22.16b\n" + ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" + ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" + "and v14.16b, v14.16b, v22.16b\n" + "sub x20, x24, #0x8\n" + "ldr d16, [x20, #0x0]\n" + "subs x21, x21, #0x1\n" + "add x25, x25, #0x88\n" + "fcvtl v17.4s, v17.4h\n" + "add x24, x24, #0x48\n" + ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" + ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" + ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" + ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" + "fcvtl v16.4s, v16.4h\n" + ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" + ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" + "fmul v23.4s, v16.4s, v17.s[0]\n" + "fmul v21.4s, v16.4s, v17.s[1]\n" + "fmul v1.4s, v16.4s, v17.s[2]\n" + "fmul v20.4s, v16.4s, v17.s[3]\n" + ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" + ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" + ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" + ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" + ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" + ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" + "uzp1 v19.2d, v8.2d, v27.2d\n" + "uzp2 v18.2d, v8.2d, v27.2d\n" + "scvtf v19.4s, v19.4s, #0x4\n" + "uzp1 v17.2d, v0.2d, v29.2d\n" + "uzp2 v16.2d, v0.2d, v29.2d\n" + "scvtf v18.4s, v18.4s, #0x4\n" + "fmla v2.4s, v19.4s, v23.4s\n" + "scvtf v17.4s, v17.4s, #0x4\n" + "scvtf v16.4s, v16.4s, #0x4\n" + "fmla v10.4s, v18.4s, v21.4s\n" + "fmla v12.4s, v17.4s, v1.4s\n" + "fmla v28.4s, v16.4s, v20.4s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x10, #0x1\n" + "str q2, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x2\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x10, #0x3\n" + "str q12, [x20, #0x0]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "str q28, [x20, #0x0]\n" + "8:" // Row tail: Accumulator store skip + "subs x23, x23, #0x4\n" + "add %x[res_ptr], %x[res_ptr], #0x10\n" + "bne 6b\n" + "subs x10, x10, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x9\n" + "mov %x[res_ptr], x22\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" + ); + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) { + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x20, #0x4\n" + "mov x13, %x[nr]\n" + "mov z28.s, #-0x4\n" + "mov x12, #0x88\n" + "ptrue p1.b\n" + "whilelt p0.s, XZR, x20\n" + "cmp x13, #0x10\n" + "mul x12, %x[nb], x12\n" + "blt 4f\n" + "1:" // Row loop + "add x11, %x[b_ptr], #0x10\n" + "mov x10, %x[nc]\n" + "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x28, %x[a_ptr], #0x8\n" + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "mov x27, %x[nb]\n" + "add x26, x28, x12\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "add x25, x26, x12\n" + "mov z13.b, #0x0\n" + "mov z1.b, #0x0\n" + "add x24, x25, x12\n" + "mov z20.b, #0x0\n" + "mov z25.b, #0x0\n" + "mov z11.b, #0x0\n" + "mov z16.b, #0x0\n" + "mov z19.b, #0x0\n" + "mov z26.b, #0x0\n" + "mov z8.b, #0x0\n" + "mov z29.b, #0x0\n" + "mov z27.b, #0x0\n" + "mov z10.b, #0x0\n" + "3:" // Block loop + "ld1b { z30.b }, p1/Z, [x11]\n" + "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" + "mov z18.s, #0x0\n" + "mov z7.s, #0x0\n" + "ld1rqb { z3.b }, p1/Z, [x28]\n" + "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" + "mov z9.s, #0x0\n" + "mov z22.s, #0x0\n" + "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" + "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" + "sub x20, x11, #0x10\n" + "sub x23, x28, #0x8\n" + "lsl z31.b, z30.b, #0x4\n" + "lsl z6.b, z21.b, #0x4\n" + "ld1h { z23.s }, p1/Z, [x20]\n" + "sub x22, x26, #0x8\n" + "and z30.b, z30.b, #0xf0\n" + "and z21.b, z21.b, #0xf0\n" + "sub x21, x25, #0x8\n" + "sub x20, x24, #0x8\n" + "lsl z14.b, z4.b, #0x4\n" + "lsl z2.b, z17.b, #0x4\n" + "subs x27, x27, #0x1\n" + "add x11, x11, #0x90\n" + ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" + ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" + "and z4.b, z4.b, #0xf0\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" + "and z17.b, z17.b, #0xf0\n" + "fcvt z23.s, p1/m, z23.h\n" + ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" + ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" + "fscale z23.s, p1/m, z23.s, z28.s\n" + ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" + ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" + "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" + "add x28, x28, #0x88\n" + ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" + ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" + "ld1h { z3.s }, p0/Z, [x23]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "fcvt z3.s, p1/m, z3.h\n" + "uzp1 z5.d, z18.d, z7.d\n" + "uzp2 z18.d, z18.d, z7.d\n" + "mov z3.q, z3.q[0]\n" + "uzp1 z7.d, z9.d, z22.d\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z3.s[0]\n" + "scvtf z5.s, p1/m, z5.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "scvtf z7.s, p1/m, z7.s\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z24.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z5.b }, p1/Z, [x26]\n" + "fmul z9.s, z23.s, z3.s[1]\n" + "fmla z15.s, p1/M, z18.s, z9.s\n" + "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" + "fmul z9.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "fmla z12.s, p1/M, z7.s, z9.s\n" + "mov z9.s, #0x0\n" + "ld1h { z7.s }, p0/Z, [x22]\n" + ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" + "fmla z0.s, p1/M, z22.s, z3.s\n" + "mov z22.s, #0x0\n" + "ld1h { z3.s }, p0/Z, [x21]\n" + ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" + "fcvt z7.s, p1/m, z7.h\n" + "fcvt z3.s, p1/m, z3.h\n" + ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" + ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" + "mov z7.q, z7.q[0]\n" + "mov z3.q, z3.q[0]\n" + ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" + ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" + ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" + ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" + "uzp1 z5.d, z9.d, z22.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "uzp2 z22.d, z9.d, z22.d\n" + "fmul z9.s, z23.s, z7.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z13.s, p1/M, z5.s, z9.s\n" + "ld1rqb { z9.b }, p1/Z, [x25]\n" + "fmul z5.s, z23.s, z7.s[1]\n" + "fmla z1.s, p1/M, z22.s, z5.s\n" + "mov z5.s, #0x0\n" + "mov z22.s, #0x0\n" + ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" + ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" + ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" + ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" + ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" + ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" + "add x26, x26, #0x88\n" + ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" + ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" + "uzp1 z18.d, z5.d, z22.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z22.d, z5.d, z22.d\n" + "fmul z5.s, z23.s, z7.s[2]\n" + "fmul z7.s, z23.s, z7.s[3]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z20.s, p1/M, z18.s, z5.s\n" + "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" + "ld1h { z5.s }, p0/Z, [x20]\n" + "fcvt z5.s, p1/m, z5.h\n" + "fmla z25.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" + "mov z5.q, z5.q[0]\n" + ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" + ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" + ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" + "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" + ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" + ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" + "uzp1 z9.d, z22.d, z7.d\n" + "scvtf z9.s, p1/m, z9.s\n" + "uzp2 z22.d, z22.d, z7.d\n" + "fmul z7.s, z23.s, z3.s[0]\n" + "scvtf z22.s, p1/m, z22.s\n" + "fmla z11.s, p1/M, z9.s, z7.s\n" + "ld1rqb { z9.b }, p1/Z, [x24]\n" + "fmul z7.s, z23.s, z3.s[1]\n" + "fmla z16.s, p1/M, z22.s, z7.s\n" + "mov z22.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" + ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" + ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" + ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" + ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" + "add x25, x25, #0x88\n" + ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" + ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" + "uzp1 z18.d, z22.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp2 z7.d, z22.d, z7.d\n" + "fmul z22.s, z23.s, z3.s[2]\n" + "fmul z3.s, z23.s, z3.s[3]\n" + "scvtf z7.s, p1/m, z7.s\n" + "fmla z19.s, p1/M, z18.s, z22.s\n" + "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" + "fmul z22.s, z23.s, z5.s[0]\n" + "fmla z26.s, p1/M, z7.s, z3.s\n" + "mov z3.s, #0x0\n" + "mov z7.s, #0x0\n" + ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" + ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" + "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" + ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" + ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" + "mov z9.s, #0x0\n" + ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" + "mov z31.s, #0x0\n" + ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" + "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" + ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" + "fmul z14.s, z23.s, z5.s[1]\n" + ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" + "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" + "fmul z2.s, z23.s, z5.s[2]\n" + "fmul z23.s, z23.s, z5.s[3]\n" + ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" + ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" + "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" + ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" + ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" + "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" + "add x24, x24, #0x88\n" + ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" + ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" + ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" + ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" + "uzp1 z18.d, z3.d, z7.d\n" + "uzp2 z5.d, z3.d, z7.d\n" + "scvtf z18.s, p1/m, z18.s\n" + "uzp1 z6.d, z9.d, z31.d\n" + "uzp2 z9.d, z9.d, z31.d\n" + "scvtf z5.s, p1/m, z5.s\n" + "fmla z8.s, p1/M, z18.s, z22.s\n" + "scvtf z6.s, p1/m, z6.s\n" + "scvtf z9.s, p1/m, z9.s\n" + "fmla z29.s, p1/M, z5.s, z14.s\n" + "fmla z27.s, p1/M, z6.s, z2.s\n" + "fmla z10.s, p1/M, z9.s, z23.s\n" + "bgt 3b\n" + "mov x20, %x[res_ptr]\n" + "subs x10, x10, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z0.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z13.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z1.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z20.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z25.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z11.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z16.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z19.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z26.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z8.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z29.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z27.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "st1w { z10.s }, p1, [x20]\n" + "bne 2b\n" + "mov x20, #0x4\n" + "sub x13, x13, #0x10\n" + "cmp x13, #0x10\n" + "mov %x[res_ptr], x9\n" + "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" + "bge 1b\n" + "4:" // Row loop skip + "cbz x13, 9f\n" + "5:" // Row tail: Row loop + "add x25, %x[b_ptr], #0x10\n" + "mov x24, %x[nc]\n" + "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" + "6:" // Row tail: Column loop + "mov z24.b, #0x0\n" + "mov z15.b, #0x0\n" + "add x28, %x[a_ptr], #0x8\n" + "mov x22, %x[nb]\n" + "mov z12.b, #0x0\n" + "mov z0.b, #0x0\n" + "7:" // Row tail: Block loop + "ld1b { z3.b }, p1/Z, [x25]\n" + "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" + "mov z2.s, #0x0\n" + "mov z25.s, #0x0\n" + "ld1rqb { z26.b }, p1/Z, [x28]\n" + "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" + "mov z27.s, #0x0\n" + "mov z19.s, #0x0\n" + "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" + "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" + "sub x21, x25, #0x10\n" + "sub x20, x28, #0x8\n" + "lsl z20.b, z3.b, #0x4\n" + "lsl z4.b, z6.b, #0x4\n" + "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" + "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" + "and z3.b, z3.b, #0xf0\n" + "and z6.b, z6.b, #0xf0\n" + "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" + "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" + "lsl z8.b, z29.b, #0x4\n" + "lsl z14.b, z16.b, #0x4\n" + "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" + "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" + ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" + ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" + "and z29.b, z29.b, #0xf0\n" + "ld1h { z17.s }, p1/Z, [x21]\n" + ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" + ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" + "and z16.b, z16.b, #0xf0\n" + "ld1h { z4.s }, p0/Z, [x20]\n" + "subs x22, x22, #0x1\n" + "add x28, x28, #0x88\n" + "fcvt z17.s, p1/m, z17.h\n" + "add x25, x25, #0x90\n" + ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" + ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" + "fcvt z4.s, p1/m, z4.h\n" + ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" + ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" + "fscale z17.s, p1/m, z17.s, z28.s\n" + "mov z4.q, z4.q[0]\n" + ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" + ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" + "fmul z23.s, z17.s, z4.s[0]\n" + "fmul z9.s, z17.s, z4.s[1]\n" + "fmul z21.s, z17.s, z4.s[2]\n" + "fmul z4.s, z17.s, z4.s[3]\n" + ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" + ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" + ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" + ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" + ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" + ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" + "uzp1 z31.d, z2.d, z25.d\n" + "uzp2 z13.d, z2.d, z25.d\n" + "scvtf z31.s, p1/m, z31.s\n" + "uzp1 z17.d, z27.d, z19.d\n" + "uzp2 z18.d, z27.d, z19.d\n" + "scvtf z13.s, p1/m, z13.s\n" + "fmla z24.s, p1/M, z31.s, z23.s\n" + "scvtf z17.s, p1/m, z17.s\n" + "scvtf z18.s, p1/m, z18.s\n" + "fmla z15.s, p1/M, z13.s, z9.s\n" + "fmla z12.s, p1/M, z17.s, z21.s\n" + "fmla z0.s, p1/M, z18.s, z4.s\n" + "bgt 7b\n" + "mov x20, %x[res_ptr]\n" + "cmp x13, #0x1\n" + "st1w { z24.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x2\n" + "st1w { z15.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "cmp x13, #0x3\n" + "st1w { z12.s }, p1, [x20]\n" + "add x20, x20, %x[res_stride]\n" + "ble 8f\n" + "st1w { z0.s }, p1, [x20]\n" + "8:" // Row tail: Accumulator store skip + "subs x24, x24, #0x8\n" + "add %x[res_ptr], %x[res_ptr], #0x20\n" + "bne 6b\n" + "subs x13, x13, #0x4\n" + "add %x[a_ptr], %x[a_ptr], x12\n" + "mov %x[res_ptr], x23\n" + "bgt 5b\n" + "9:" // Row tail: Row loop skip + : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) + : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) + : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" + ); + return; + } +#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k = 0; k < 4; k++) { + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + + uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); + int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); + int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + + sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); + sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c new file mode 100644 index 000000000..f2ea96572 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -0,0 +1,2638 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +#if defined(__loongarch_sx) + +static __m128i lsx_packs_w(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_w(a, 15); + tmp1 = __lsx_vsat_w(b, 15); + return __lsx_vpickev_h(tmp1, tmp); +} + +static __m128i lsx_packs_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_h(a, 7); + tmp1 = __lsx_vsat_h(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_packus_h(__m128i a, __m128i b) { + __m128i tmp, tmp1; + tmp = __lsx_vsat_hu(a, 7); + tmp1 = __lsx_vsat_hu(b, 7); + return __lsx_vpickev_b(tmp1, tmp); +} + +static __m128i lsx_maddubs_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_h_b(a, b); + tmp2 = __lsx_vmulwod_h_b(a, b); + return __lsx_vsadd_h(tmp1, tmp2); +} + +static __m128i lsx_madd_h(__m128i a, __m128i b) { + __m128i tmp1, tmp2; + tmp1 = __lsx_vmulwev_w_h(a, b); + tmp2 = __lsx_vmulwod_w_h(a, b); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { + v4i32 __ret = {d, c, b, a}; + return (__m128i)__ret; +} + +static __m128i lsx_shuffle_b(__m128i a, __m128i b) { + __m128i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lsx_vreplgr2vr_b(f); + zero = __lsx_vldi(0); + tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones + return __lsx_vshuf_b(a, zero, tmp2); +} + +static __m128i lsx_hadd_h(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_h(b, a); + __m128i tmp2 = __lsx_vpickod_h(b, a); + return __lsx_vadd_h(tmp1, tmp2); +} + +static __m128i lsx_hadd_w(__m128i a, __m128i b) { + __m128i tmp1 = __lsx_vpickev_w(b, a); + __m128i tmp2 = __lsx_vpickod_w(b, a); + return __lsx_vadd_w(tmp1, tmp2); +} + +static __m128 lsx_hadd_s(__m128 a, __m128 b) { + __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); + __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); + + return __lsx_vfadd_s(tmp1, tmp2); +} + +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =lsx_hadd_s(a, b); + __m128 res_1 =lsx_hadd_s(c, d); + __m128 res =lsx_hadd_s(res_0, res_1); + res =lsx_hadd_s(res, res); + res =lsx_hadd_s(res, res); + + return ((v4f32)res)[0]; +} +#endif + +#if defined(__loongarch_asx) + +#ifdef __clang__ +#define VREGS_PREFIX "$vr" +#define XREGS_PREFIX "$xr" +#else // GCC +#define VREGS_PREFIX "$f" +#define XREGS_PREFIX "$f" +#endif +#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" +// Convert __m128i to __m256i +static inline __m256i ____m256i(__m128i in) { + __m256i out = __lasx_xvldi(0); + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX"\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "+f" (out) : [in] "f" (in) + ); + return out; +} +// Convert two __m128i to __m256i +static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { + __m256i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".ifnc %[out], %[hi] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " XREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" + " xvori.b $xr\\i, $xr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out), [hi] "+f" (inhi) + : [lo] "f" (inlo) + ); + return out; +} +// Convert __m256i low part to __m128i +static inline __m128i lasx_extracti128_lo(__m256i in) { + __m128i out; + __asm__ volatile ( + ".ifnc %[out], %[in] \n\t" + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " vori.b $vr\\i, $vr\\j, 0 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + ".endif \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} +// Convert __m256i high part to __m128i +static inline __m128i lasx_extracti128_hi(__m256i in) { + __m128i out; + __asm__ volatile ( + ".irp i," __ALL_REGS "\n\t" + " .ifc %[out], " VREGS_PREFIX "\\i \n\t" + " .irp j," __ALL_REGS "\n\t" + " .ifc %[in], " XREGS_PREFIX "\\j \n\t" + " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" + " .endif \n\t" + " .endr \n\t" + " .endif \n\t" + ".endr \n\t" + : [out] "=f" (out) : [in] "f" (in) + ); + return out; +} + +static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { + v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; + return (__m256i)__ret; +} + +static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { + v4i64 __ret = {d, c, b, a}; + return (__m256i)__ret; +} + +static __m256i lasx_insertf128( __m128i x, __m128i y) { + return lasx_set_q(x, y); +} + +static __m256i lasx_shuffle_b(__m256i a, __m256i b) { + __m256i mask_f, zero, tmp0, tmp2, mask; + int f = 0x8f; + mask_f = __lasx_xvreplgr2vr_b(f); + zero = __lasx_xvldi(0); + tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits + tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive + mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask + tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones + return __lasx_xvshuf_b(a, zero, tmp2); +} + +static __m256i lasx_extu8_16(__m128i a) { + return __lasx_vext2xv_hu_bu(____m256i(a)); +} + +static __m256i lasx_ext8_16(__m128i a) { + return __lasx_vext2xv_h_b(____m256i(a)); +} + +static __m256i lasx_ext16_32(__m128i a) { + return __lasx_vext2xv_w_h(____m256i(a)); +} + +static __m128i lasx_extracti128( __m256i a, int pos) { + __m128i ret; + if( pos == 0) + { + ret = lasx_extracti128_lo(a); + } else { + ret = lasx_extracti128_hi(a); + } + return ret; +} + +static __m128 lasx_extractf128( __m256 a, int pos) { + __m128 ret; + if( pos == 0) + { + ret = (__m128)lasx_extracti128_lo((__m256i)a); + } else { + ret = (__m128)lasx_extracti128_hi((__m256i)a); + } + return ret; +} + +static __m256i lasx_maddubs_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvsadd_h(tmp1, tmp2); +} + +static __m256i lasx_madd_h(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_w_h(a, b); + tmp2 = __lasx_xvmulwod_w_h(a, b); + return __lasx_xvadd_w(tmp1, tmp2); +} + +static __m256i lasx_packs_w(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_w(a, 15); + tmp1 = __lasx_xvsat_w(b, 15); + return __lasx_xvpickev_h(tmp1, tmp); +} + +static __m256i lasx_packs_h(__m256i a, __m256i b) { + __m256i tmp, tmp1; + tmp = __lasx_xvsat_h(a, 7); + tmp1 = __lasx_xvsat_h(b, 7); + return __lasx_xvpickev_b(tmp1, tmp); +} + +static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) { + __m256i tmp1, tmp2; + tmp1 = __lasx_xvmulwev_h_b(a, b); + tmp2 = __lasx_xvmulwod_h_b(a, b); + return __lasx_xvadd_h(tmp1, tmp2); +} + +static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) { + switch (b) { + case 0: return __lasx_xvrepl128vei_h(a, 0); + case 1: return __lasx_xvrepl128vei_h(a, 1); + case 2: return __lasx_xvrepl128vei_h(a, 2); + case 3: return __lasx_xvrepl128vei_h(a, 3); + case 4: return __lasx_xvrepl128vei_h(a, 4); + case 5: return __lasx_xvrepl128vei_h(a, 5); + case 6: return __lasx_xvrepl128vei_h(a, 6); + case 7: return __lasx_xvrepl128vei_h(a, 7); + default: __builtin_unreachable(); + } +} + +static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) { + switch (b) { + case 0: return __lasx_xvandi_b(a, 1 << 0); + case 1: return __lasx_xvandi_b(a, 1 << 1); + case 2: return __lasx_xvandi_b(a, 1 << 2); + case 3: return __lasx_xvandi_b(a, 1 << 3); + case 4: return __lasx_xvandi_b(a, 1 << 4); + case 5: return __lasx_xvandi_b(a, 1 << 5); + case 6: return __lasx_xvandi_b(a, 1 << 6); + case 7: return __lasx_xvandi_b(a, 1 << 7); + default: __builtin_unreachable(); + } +} + +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = __lsx_vsigncov_b(x, x); + // Sign the values of the y vectors + const __m128i sy = __lsx_vsigncov_b(x, y); + // Perform multiplication and create 16-bit values + const __m128i dot = lsx_maddubs_h(ax, sy); + const __m128i ones = __lsx_vreplgr2vr_h(1); + return lsx_madd_h(ones, dot); +} + +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = lasx_extractf128(x, 1); + res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); + res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); + return ((v4f32)res)[0]; +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + + __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); + __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); + + __m128i tmp1_128 = lasx_extracti128_lo(tmp1); + __m128i tmp2_128 = lasx_extracti128_lo(tmp2); + + __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); + + __m128i ev = __lsx_vpickev_w(sum128, sum128); + __m128i od = __lsx_vpickod_w(sum128, sum128); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + __m128i ev = __lsx_vpickev_w(a, a); + __m128i od = __lsx_vpickod_w(a, a); + __m128i sum64 = __lsx_vadd_w(ev, od); + + int sum64_1, sum64_2; + sum64_1 = __lsx_vpickve2gr_w(sum64, 0); + sum64_2 = __lsx_vpickve2gr_w(sum64, 1); + + return sum64_1 + sum64_2; +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = lasx_set_d( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + + __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); + const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); + bytes = __lasx_xvor_v(bytes, bit_mask); + return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); + __m128i hi = __lsx_vsrli_h(lo, 4); + return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + __m256i v = __lasx_xvpackod_h(x, x); + __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); + return __lasx_xvffint_s_w(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + // Perform multiplication and create 16-bit values + const __m256i dot = lasx_maddubs_h(ax, sy); + return sum_i16_pairs_float(dot); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m256i dot = lasx_madd_h_b(x, y); + return sum_i16_pairs_float(dot); +} + +static inline __m128i packNibbles( __m256i bytes ) { + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); + __m256i high = __lasx_xvandn_v(lowByte, bytes); + __m256i low = __lasx_xvand_v(lowByte, bytes); + high = __lasx_xvsrli_h(high, 4); + bytes = __lasx_xvor_v(low, high); + // Compress uint16_t lanes into bytes + __m128i *r0 = (__m128i *)&bytes; + __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); + __m128i *r1 = (__m128i *)&tmp_h128; + + __m128i zero = __lsx_vldi(0); + __m128i tmp, tmp2, tmp3; + + tmp = __lsx_vmax_h(zero, *r0); + tmp2 = __lsx_vsat_hu(tmp, 7); + + tmp = __lsx_vmax_h(zero, *r1); + tmp3 = __lsx_vsat_hu(tmp, 7); + return __lsx_vpickev_b(tmp3, tmp2); +} +#endif //__loongarch_asx + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + __m256 v0 = (__m256)__lasx_xvld( x , 0); + __m256 v1 = (__m256)__lasx_xvld( x , 32); + __m256 v2 = (__m256)__lasx_xvld( x , 64); + __m256 v3 = (__m256)__lasx_xvld( x , 96); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); + const float max_scalar = ((v4f32)max4)[0]; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128( i0, 0 ); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); + + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__loongarch_asx) + for (int i = 0; i < nb; i++) { + __m256 v0 = (__m256)__lasx_xvld( x , 0 ); + __m256 v1 = (__m256)__lasx_xvld( x , 32 ); + __m256 v2 = (__m256)__lasx_xvld( x , 64 ); + __m256 v3 = (__m256)__lasx_xvld( x , 96 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); + __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); + max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); + + __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); + __m128 tmp = max4; + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); + const float max_scalar = ((v4f32)max4)[0]; + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = __lasx_xvreplfr2vr_s( id ); + + // Apply the multiplier + v0 = __lasx_xvfmul_s( v0, mul ); + v1 = __lasx_xvfmul_s( v1, mul ); + v2 = __lasx_xvfmul_s( v2, mul ); + v3 = __lasx_xvfmul_s( v3, mul ); + + // Round to nearest integer + __m256i i0 = __lasx_xvftintrne_w_s( v0 ); + __m256i i1 = __lasx_xvftintrne_w_s( v1 ); + __m256i i2 = __lasx_xvftintrne_w_s( v2 ); + __m256i i3 = __lasx_xvftintrne_w_s( v3 ); + + __m128i ni0 = lasx_extracti128(i0, 0); + __m128i ni1 = lasx_extracti128( i0, 1); + __m128i ni2 = lasx_extracti128( i1, 0); + __m128i ni3 = lasx_extracti128( i1, 1); + __m128i ni4 = lasx_extracti128( i2, 0 ); + __m128i ni5 = lasx_extracti128( i2, 1); + __m128i ni6 = lasx_extracti128( i3, 0); + __m128i ni7 = lasx_extracti128( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3)); + const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7)); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1))); + + // Convert int32 to int16 + ni0 = lsx_packs_w( ni0, ni1 ); + ni2 = lsx_packs_w( ni2, ni3 ); + ni4 = lsx_packs_w( ni4, ni5 ); + ni6 = lsx_packs_w( ni6, ni7 ); + // Convert int16 to int8 + ni0 = lsx_packs_h( ni0, ni2 ); + ni4 = lsx_packs_h( ni4, ni6 ); + + __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); + __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + + +//===================================== Dot products ================================= + +// +// Helper functions +// + +#if defined(__loongarch_asx) +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return __lasx_xvld((const __m256i*)k_shuffle + i, 0); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return __lsx_vld((const __m128i*)k_shuffle + i, 0); +} +#endif + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = __lasx_xvreplgr2vr_b( 8 ); + qx = __lasx_xvsub_b( qx, off ); + + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + sumf = hsum_float_8(acc); + +#elif defined(__loongarch_sx) + // set constants + const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); + const __m128i off = __lsx_vreplgr2vr_b(8); + + // Initialize accumulator with zeros + __m128 acc_0 = (__m128)__lsx_vldi(0); + __m128 acc_1 = (__m128)__lsx_vldi(0); + __m128 acc_2 = (__m128)__lsx_vldi(0); + __m128 acc_3 = (__m128)__lsx_vldi(0); + + for (; ib + 1 < nb; ib += 2) { + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); + + __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); + __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); + bx_0 = __lsx_vsub_b(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); + __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); + bx_1 = __lsx_vsub_b(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); + + const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); + + __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); + __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); + bx_2 = __lsx_vsub_b(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); + __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); + bx_3 = __lsx_vsub_b(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = __lsx_vffint_s_w(i32_0); + __m128 p1 = __lsx_vffint_s_w(i32_1); + __m128 p2 = __lsx_vffint_s_w(i32_2); + __m128 p3 = __lsx_vffint_s_w(i32_3); + + // Apply the scale + __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 ); + __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 ); + __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 ); + __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 ); + + // Acummulate + acc_0 = __lsx_vfadd_s(p0_d, acc_0); + acc_1 = __lsx_vfadd_s(p1_d, acc_1); + acc_2 = __lsx_vfadd_s(p2_d, acc_2); + acc_3 = __lsx_vfadd_s(p3_d, acc_3); + } + + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0; + + // Main loop + for (; ib < nb; ++ib) { + const float d0 = GGML_FP16_TO_FP32(x[ib].d); + const float d1 = GGML_FP16_TO_FP32(y[ib].d); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); + const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); + + // Compute combined scales + const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y + acc = __lasx_xvfmadd_s( d0d1, xy, acc ); + } + + sumf = hsum_float_8(acc) + summs; + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); + qx = __lasx_xvor_v(qx, bxhi); + + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = __lasx_xvfmadd_s(d, q, acc); + } + + sumf = hsum_float_8(acc); + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d)); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); + qx = __lasx_xvor_v(qx, bxhi); + + const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_us8_pairs_float(qx, qy); + + acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); + } + + sumf = hsum_float_8(acc) + summs; + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__loongarch_asx) + // Initialize accumulator with zeros + __m256 acc = (__m256)__lasx_xvldi(0); + + // Main loop + for (; ib < nb; ++ib) { + // Compute combined scale for the block + const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); + __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = __lasx_xvfmadd_s( d, q, acc ); + } + + sumf = hsum_float_8(acc); + +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __loongarch_asx + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); + const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf); + const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4)); + const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); + + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32; + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3); + const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3); + const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3); + const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6); + + __m256i p0 = lasx_madd_h_b(q2_0, q8_0); + __m256i p1 = lasx_madd_h_b(q2_1, q8_1); + __m256i p2 = lasx_madd_h_b(q2_2, q8_2); + __m256i p3 = lasx_madd_h_b(q2_3, q8_3); + + p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0); + p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1); + p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2); + p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3); + + p0 = __lasx_xvadd_w(p0, p1); + p2 = __lasx_xvadd_w(p2, p3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2)); + } + + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __loongarch_asx + + const __m128i m32 = __lsx_vreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = lsx_set_w( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = __lsx_vsub_b(scales128, m32); + + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); + + // high bit + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); + + // integer accumulator + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3); + const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3); + const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3); + const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6); + const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2); + const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2); + const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2); + const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2); + const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0); + const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1); + const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2); + const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3); + + // load Q8 quants + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0); + __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1); + __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2); + __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3); + + // multiply with scales + p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); + p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); + p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); + p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); + + // accumulate + p16_0 = __lasx_xvadd_w(p16_0, p16_1); + p16_2 = __lasx_xvadd_w(p16_2, p16_3); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); + } + // multiply with block scale and accumulate + acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __loongarch_asx + + __m256 acc = (__m256)__lasx_xvldi(0); + __m128 acc_m = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); + const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(mins128, q8s); + acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); + + const __m256i scales = lasx_insertf128(scales128, scales128); + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0); + const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1); + + const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf); + const __m256i q4h = __lasx_xvsrli_b(q4bits, 4); + + const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16l = lasx_madd_h_b(q4l, q8l); + p16l = lasx_madd_h(scale_l, p16l); + + const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + __m256i p16h = lasx_madd_h_b(q4h, q8h); + p16h = lasx_madd_h(scale_h, p16h); + const __m256i sumj = __lasx_xvadd_w(p16l, p16h); + + sumi = __lasx_xvadd_w(sumi, sumj); + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); + __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); + acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); + + + *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __loongarch_asx + + __m256 acc = (__m256)__lasx_xvldi(0); + __m128 acc_m = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); + const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); + + const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); + const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); + const __m128i prod = lsx_madd_h(mins128, q8s); + acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); + + const __m256i scales = lasx_insertf128(scales128, scales128); + + const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0); + const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1); + + const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; + + const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf); + const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4); + const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef); + const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef); + const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0); + const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0); + __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1); + + p16_0 = lasx_madd_h(scale_0, p16_0); + p16_1 = lasx_madd_h(scale_1, p16_1); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + + } + + __m256 vd = __lasx_xvreplfr2vr_s(d); + acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); + + } + + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8)); + acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4)); + + *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __loongarch_asx + + const __m256i m32s = __lasx_xvreplgr2vr_b(32); + + __m256 acc = (__m256)__lasx_xvldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); + const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; + const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); + + __m256i sumi = __lasx_xvldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; + const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; + + const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4); + const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2); + const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4); + const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2); + + const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0); + const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1); + const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2); + const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3); + + const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0); + __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1); + __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2); + __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3); + + p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); + p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); + p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); + p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); + } + + acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#if defined(__loongarch_asx) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + + const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__loongarch_asx) + + const __m256i mone = __lasx_xvreplgr2vr_b(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0); + const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0); + const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0); + const __m256i m511 = __lasx_xvreplgr2vr_h(511); + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = __lsx_vreplgr2vr_d(aux64); + stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4)); + const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1); + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16; + aux_gindex = __lasx_xvand_v(q2_data, m511); + + const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9); + const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13); + const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits); + + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + + const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + + const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); + const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); + const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); + const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); + + __m256i signs; + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1); + + signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3); + + signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2); + signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4); + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3); + const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4); + + const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3))); + + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2)); + sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3)); + sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4)); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + + const __m128i m4 = __lsx_vreplgr2vr_b(0xf); + const __m128i m1 = __lsx_vreplgr2vr_b(1); + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + uint64_t aux64; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + __m128i tmp1; + memcpy(&aux64, x[i].scales, 8); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0); + tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1); + const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1); + const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; + +#endif + +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__loongarch_asx) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + + const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); + const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = 0.25f * hsum_float_8(accumf); + +#else + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__loongarch_asx) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); + const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); + + __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = __lasx_xvreplgr2vr_w(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = (__m256)__lasx_xvldi(0); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16; + idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]); + idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]); + idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0))); + idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = lasx_set_w( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = lasx_set_w( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); + + aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16)); + aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); + const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); + const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); + const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); + const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); + sumi1 = __lasx_xvadd_w(sumi1, p1); + sumi2 = __lasx_xvadd_w(sumi2, p2); + } + + accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); + } + + *s = hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} + +#if defined(__loongarch_asx) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i a = __lasx_xvmulwev_h_b(x, y); + const __m256i b = __lasx_xvmulwod_h_b(x, y); + return __lasx_xvadd_h(a, b); +} +#endif + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__loongarch_asx) + + __m256 accum = (__m256)__lasx_xvldi(0); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m256i sumi = __lasx_xvldi(0); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2); + q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3); + + __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2); + q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3); + + qs += 8; + const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + + __m256i tmp1, tmp5, tmp6; + tmp1 = __lasx_xvreplgr2vr_h(ls1); + tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1); + const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6); + + tmp1 = __lasx_xvreplgr2vr_h(ls2); + tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1); + tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1); + const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6); + + sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum); + accum1 += d * sumi1; + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; + } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined (__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); + const __m256i mone = __lasx_xvreplgr2vr_h(1); + + __m256 accum1 = (__m256)__lasx_xvldi(0); + __m256 accum2 = (__m256)__lasx_xvldi(0); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); + const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); + const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); + const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); + const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); + const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), + lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = lasx_madd_h(p16_1, mone); + const __m256i p_2 = lasx_madd_h(p16_2, mone); + accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), + __lasx_xvffint_s_w(p_1), accum1); + accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), + __lasx_xvffint_s_w(p_2), accum2); + } + + sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__loongarch_asx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + + __m256 accum = (__m256)__lasx_xvldi(0); + + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = __lasx_xvldi(0); + __m256i sumi2 = __lasx_xvldi(0); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; + const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)), + __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf))); + const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)), + __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1)); + const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2)); + sumi1 = __lasx_xvadd_w(p_1, sumi1); + sumi2 = __lasx_xvadd_w(p_2, sumi2); + } + accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c new file mode 100644 index 000000000..ce4e47a86 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -0,0 +1,2731 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +#if defined(__POWER9_VECTOR__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + for (int i = 0; i < nb; i++) { + vector float srcv [8]; + vector float asrcv[8]; + vector float amaxv[8]; + vector signed int vi[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + const vector float vid = vec_splats(id); + + y[i].d = GGML_FP32_TO_FP16(d); + + vector int accv = vec_splats(0); + + for (int j = 0; j < 8; j++) { + const vector float v = vec_round(vec_mul(srcv[j], vid)); + vi[j] = vec_cts(v, 0); + + accv = vec_add(accv, vi[j]); + } + vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); + vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); + + accv = vec_add(accv, vec_sld(accv, accv, 4)); + accv = vec_add(accv, vec_sld(accv, accv, 8)); + y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); + } + +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + + +//===================================== Dot products ================================= + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_sub(q4x0, v8); + q4x1 = vec_sub(q4x1, v8); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi0 = vec_sum4s(qv1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); + vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q4x0, vsumi0); + vsumi0 = vec_msum(q8y1, q4x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; + vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); + vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); + + qv0 = vec_add(qv0, qv1); + + vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); + vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; + vsumf0 = vec_madd(vxmin, vys, vsumf0); + + vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; + + vector signed char qh0 = (vector signed char)aux64x2_0; + vector signed char qh1 = (vector signed char)aux64x2_1; + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + + vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); + vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl( 16, y[ib].qs); + + vector signed int vsumi0 = v0; + + vsumi0 = vec_msum(q8y0, q5x0, vsumi0); + vsumi0 = vec_msum(q8y1, q5x1, vsumi0); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char q8x0 = vec_xl( 0, x[ib].qs); + vector signed char q8x1 = vec_xl(16, x[ib].qs); + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_mule(q8x0, q8y0); + vector signed short qv1 = vec_mulo(q8x0, q8y0); + vector signed short qv2 = vec_mule(q8x1, q8y1); + vector signed short qv3 = vec_mulo(q8x1, q8y1); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + vsumi0 = vec_sum4s(qv2, vsumi0); + vsumi1 = vec_sum4s(qv3, vsumi1); + + vsumi0 = vec_add(vsumi0, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + } + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowScaleMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); + vector signed char vscales = vec_and(q2xmins, lowScaleMask); + + q2xmins = vec_sr(q2xmins, v4); + vector signed short q2xmins0 = vec_unpackh(q2xmins); + vector signed short q2xmins1 = vec_unpackl(q2xmins); + + vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); + vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); + vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); + vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); + vector signed char qxs1 = (vector signed char)vec_xl(16, q2); + q2 += 32; + + vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); + vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); + vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); + vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); + vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); + vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv0 = vec_msum(q8y00, q2x00, v0); + vector signed int qv1 = vec_msum(q8y01, q2x01, v0); + vector signed int qv2 = vec_msum(q8y02, q2x02, v0); + vector signed int qv3 = vec_msum(q8y03, q2x03, v0); + vector signed int qv4 = vec_msum(q8y10, q2x10, v0); + vector signed int qv5 = vec_msum(q8y11, q2x11, v0); + vector signed int qv6 = vec_msum(q8y12, q2x12, v0); + vector signed int qv7 = vec_msum(q8y13, q2x13, v0); + + vector signed short vscales_07 = vec_unpackh(vscales); + vector signed int vscales_03 = vec_unpackh(vscales_07); + vector signed int vscales_47 = vec_unpackl(vscales_07); + vector signed int vs0 = vec_splat(vscales_03, 0); + vector signed int vs1 = vec_splat(vscales_03, 1); + vector signed int vs2 = vec_splat(vscales_03, 2); + vector signed int vs3 = vec_splat(vscales_03, 3); + vector signed int vs4 = vec_splat(vscales_47, 0); + vector signed int vs5 = vec_splat(vscales_47, 1); + vector signed int vs6 = vec_splat(vscales_47, 2); + vector signed int vs7 = vec_splat(vscales_47, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); + vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); + vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); + vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); + vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); + vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); + vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0x3); + const vector signed char lowMask1 = vec_splats((int8_t)0xf); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector signed char v1 = vec_splats((signed char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(u0, lowMask1); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); + vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); + vector signed char u31 = vec_and(u3, lowMask2); + + u1 = vec_or(u1, u30); + u2 = vec_or(vec_sr(u0, v4), u31); + + vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); + + vscales = vec_sub(vscales, off); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); + vector signed char qxs1 = (vector signed char)vec_xl(16, q3); + q3 += 32; + + //the low 2 bits + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); + vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); + vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); + vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); + vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); + + //the 3rd bit + vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); + vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); + vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); + vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); + vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); + vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); + vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); + vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); + qxhs0 = vec_sr(qxhs0, v4); + qxhs1 = vec_sr(qxhs1, v4); + + vector signed char q3x00 = vec_sub(qxs00, qxh00); + vector signed char q3x01 = vec_sub(qxs01, qxh01); + vector signed char q3x02 = vec_sub(qxs02, qxh02); + vector signed char q3x03 = vec_sub(qxs03, qxh03); + vector signed char q3x10 = vec_sub(qxs10, qxh10); + vector signed char q3x11 = vec_sub(qxs11, qxh11); + vector signed char q3x12 = vec_sub(qxs12, qxh12); + vector signed char q3x13 = vec_sub(qxs13, qxh13); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y02 = vec_xl( 64, q8); + vector signed char q8y12 = vec_xl( 80, q8); + vector signed char q8y03 = vec_xl( 96, q8); + vector signed char q8y13 = vec_xl(112, q8); + q8 += 128; + + vector signed short vscales_h = vec_unpackh(vscales); + vector signed short vs0 = vec_splat(vscales_h, 0); + vector signed short vs1 = vec_splat(vscales_h, 1); + vector signed short vs2 = vec_splat(vscales_h, 2); + vector signed short vs3 = vec_splat(vscales_h, 3); + vector signed short vs4 = vec_splat(vscales_h, 4); + vector signed short vs5 = vec_splat(vscales_h, 5); + vector signed short vs6 = vec_splat(vscales_h, 6); + vector signed short vs7 = vec_splat(vscales_h, 7); + vscales = vec_sld(vscales, vscales, 8); + + vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); + vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); + vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); + vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); + vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); + vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); + vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); + vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs2, vsumi1); + vsumi2 = vec_msum(qv02, vs4, vsumi2); + vsumi3 = vec_msum(qv03, vs6, vsumi3); + vsumi4 = vec_msum(qv10, vs1, vsumi4); + vsumi5 = vec_msum(qv11, vs3, vsumi5); + vsumi6 = vec_msum(qv12, vs5, vsumi6); + vsumi7 = vec_msum(qv13, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((uint8_t)2); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short vscales = vec_unpackh(utmps); + vector signed short q4xmins = vec_unpackl(utmps); + vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); + vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); + + vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; j+=2) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + vector signed char qxs2 = (vector signed char)vec_xl(32, q4); + vector signed char qxs3 = (vector signed char)vec_xl(48, q4); + q4 += 64; + + vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); + vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); + vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); + vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); + vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); + vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); + vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); + vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y01 = vec_xl( 32, q8); + vector signed char q8y11 = vec_xl( 48, q8); + vector signed char q8y20 = vec_xl( 64, q8); + vector signed char q8y30 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed int qv00 = vec_msum(q8y00, q4x00, v0); + vector signed int qv01 = vec_msum(q8y01, q4x01, v0); + vector signed int qv10 = vec_msum(q8y10, q4x10, v0); + vector signed int qv11 = vec_msum(q8y11, q4x11, v0); + vector signed int qv20 = vec_msum(q8y20, q4x20, v0); + vector signed int qv21 = vec_msum(q8y21, q4x21, v0); + vector signed int qv30 = vec_msum(q8y30, q4x30, v0); + vector signed int qv31 = vec_msum(q8y31, q4x31, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vector signed int vs2 = vec_splat(vscales_h, 2); + vector signed int vs3 = vec_splat(vscales_h, 3); + vscales = vec_sld(vscales, vscales, 8); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); + + vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); + vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); + vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed char lowMask1 = vec_splats((int8_t)0x3f); + const vector signed char lowMask2 = vec_splats((int8_t)0x30); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v1 = vec_splats((unsigned char)0x1); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); + vector float vdmin = vec_mul(vxmin, vyd); + + UNUSED(kmask1); + UNUSED(kmask2); + UNUSED(kmask3); + UNUSED(utmp); + + vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); + vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); + vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); + vector signed char u3 = vec_sr(u2, v4); + + vector signed char u30 = u1; + vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); + + u1 = vec_and(u0, lowMask1); + u2 = vec_or(u30, u31); + + vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); + + vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); + vector signed short q8ysums1 = vec_xl(16, y[i].bsums); + + vector signed short vscales = vec_unpackh(utmps); + + vector signed short q5xmins = vec_unpackl(utmps); + vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); + vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); + + vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); + vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); + vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); + vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); + + vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); + vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); + vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); + vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q5, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); + vector signed char qxs1 = (vector signed char)vec_xl(16, q5); + q5 += 32; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + + vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); + vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); + vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); + vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); + qxhs0 = vec_sr(qxhs0, v2); + qxhs1 = vec_sr(qxhs1, v2); + + vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); + vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); + vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); + vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl(16, q8); + vector signed char q8y01 = vec_xl(32, q8); + vector signed char q8y11 = vec_xl(48, q8); + q8 += 64; + + vector signed int qv00 = vec_msum(q8y00, q5x00, v0); + vector signed int qv01 = vec_msum(q8y01, q5x01, v0); + vector signed int qv10 = vec_msum(q8y10, q5x10, v0); + vector signed int qv11 = vec_msum(q8y11, q5x11, v0); + + vector signed int vscales_h = vec_unpackh(vscales); + vector signed int vs0 = vec_splat(vscales_h, 0); + vector signed int vs1 = vec_splat(vscales_h, 1); + vscales = vec_sld(vscales, vscales, 12); + + vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); + vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); + vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); + vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v2 = vec_splats((unsigned char)0x2); + const vector unsigned char v3 = vec_splats((unsigned char)0x3); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector unsigned char v6 = vec_splats((unsigned char)0x6); + const vector signed char off = vec_splats((signed char)0x20); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + vector signed int vsumi4 = v0; + vector signed int vsumi5 = v0; + vector signed int vsumi6 = v0; + vector signed int vsumi7 = v0; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT qs = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/128; ++j) { + __builtin_prefetch(q6, 0, 0); + __builtin_prefetch(qh, 0, 0); + __builtin_prefetch(q8, 0, 0); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); + vector signed char qxs1 = (vector signed char)vec_xl(16, q6); + vector signed char qxs2 = (vector signed char)vec_xl(32, q6); + vector signed char qxs3 = (vector signed char)vec_xl(48, q6); + q6 += 64; + + vector signed char qxs00 = vec_and(qxs0, lowMask); + vector signed char qxs01 = vec_sr(qxs0, v4); + vector signed char qxs10 = vec_and(qxs1, lowMask); + vector signed char qxs11 = vec_sr(qxs1, v4); + vector signed char qxs20 = vec_and(qxs2, lowMask); + vector signed char qxs21 = vec_sr(qxs2, v4); + vector signed char qxs30 = vec_and(qxs3, lowMask); + vector signed char qxs31 = vec_sr(qxs3, v4); + + vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); + vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); + qh += 32; + + vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); + vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); + vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); + vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); + vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); + vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); + vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); + vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); + + vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); + vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); + vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); + vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); + vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); + vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); + vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); + vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); + + vector signed char q8y00 = vec_xl( 0, q8); + vector signed char q8y10 = vec_xl( 16, q8); + vector signed char q8y20 = vec_xl( 32, q8); + vector signed char q8y30 = vec_xl( 48, q8); + vector signed char q8y01 = vec_xl( 64, q8); + vector signed char q8y11 = vec_xl( 80, q8); + vector signed char q8y21 = vec_xl( 96, q8); + vector signed char q8y31 = vec_xl(112, q8); + q8 += 128; + + vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); + vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); + vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); + vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); + vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); + vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); + vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); + vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); + + vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); + qs += 8; + + vector signed short vs0 = vec_splat(vscales, 0); + vector signed short vs1 = vec_splat(vscales, 1); + vector signed short vs2 = vec_splat(vscales, 2); + vector signed short vs3 = vec_splat(vscales, 3); + vector signed short vs4 = vec_splat(vscales, 4); + vector signed short vs5 = vec_splat(vscales, 5); + vector signed short vs6 = vec_splat(vscales, 6); + vector signed short vs7 = vec_splat(vscales, 7); + + vsumi0 = vec_msum(qv00, vs0, vsumi0); + vsumi1 = vec_msum(qv01, vs4, vsumi1); + vsumi2 = vec_msum(qv10, vs1, vsumi2); + vsumi3 = vec_msum(qv11, vs5, vsumi3); + vsumi4 = vec_msum(qv20, vs2, vsumi4); + vsumi5 = vec_msum(qv21, vs6, vsumi5); + vsumi6 = vec_msum(qv30, vs3, vsumi6); + vsumi7 = vec_msum(qv31, vs7, vsumi7); + } + + vsumi0 = vec_add(vsumi0, vsumi4); + vsumi1 = vec_add(vsumi1, vsumi5); + vsumi2 = vec_add(vsumi2, vsumi6); + vsumi3 = vec_add(vsumi3, vsumi7); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#if defined (__POWER9_VECTOR__) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + memcpy(aux32, q2, 4*sizeof(uint32_t)); + q2 += 8; + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = aux32[1] >> 28; + const uint16_t ls1 = aux32[3] >> 28; + + vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector int v0 = vec_splats((int32_t)0); + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/64; ++j) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; + + vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; + vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; + vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; + vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; + q2 += 8; + + vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); + vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); + vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); + vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q2, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; + q2 += 8; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); + vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); + vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); + vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); + vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); + const uint16_t ls3 = (uint16_t)(sc[1] >> 4); + sc += 2; + + vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); + vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); + vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); + vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); + + vsumi0 = vec_msum(qv0, vscales0, vsumi0); + vsumi1 = vec_msum(qv1, vscales1, vsumi1); + vsumi2 = vec_msum(qv2, vscales2, vsumi2); + vsumi3 = vec_msum(qv3, vscales3, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.125f * vec_extract(vsumf0, 0); + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; + +#endif + +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint32_t * GGML_RESTRICT signs = (const uint32_t *)(x[i].qs + QK_K/4); + const int8_t * GGML_RESTRICT q8 = y[i].qs; + +#pragma GCC unroll 1 + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; + vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; + vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; + vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; + q3 += 16; + + vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; + vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; + vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; + vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; + + vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); + vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); + vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); + vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(signs[0] >> 28); + const uint16_t ls1 = (uint16_t)(signs[1] >> 28); + signs += 2; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = 0.25f * vec_extract(vsumf0, 0); + +#else + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; + + const vector int v0 = vec_splats((int32_t)0); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector unsigned char mask0 = vec_xl( 0, k_mask1); + const vector unsigned char mask1 = vec_xl(16, k_mask1); + const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].signs); + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q3, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], + iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; + vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], + iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; + vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], + iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; + vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], + iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; + q3 += 16; + qh += 2; + + vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); + vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); + signs += 4; + + vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); + vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); + vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); + vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); + + vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); + vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); + vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); + vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); + + vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); + vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); + vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); + vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); + + const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); + const uint16_t ls1 = (uint16_t)(sc[0] >> 4); + sc ++; + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector unsigned char v0 = vec_splats((unsigned char)0x0); + const vector unsigned short vsign = vec_splats((unsigned short)0x8000); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + for (int i = 0; i < nb; ++i) { + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); + vector float vyd = vec_splats(y[i].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = vec_splats((int32_t)0); + vector signed int vsumi1 = vec_splats((int32_t)0); + vector signed int vsumi2 = vec_splats((int32_t)0); + vector signed int vsumi3 = vec_splats((int32_t)0); + vector signed int vsumi8 = vec_splats((int32_t)0); + + const uint8_t * GGML_RESTRICT q1 = x[i].qs; + const uint16_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const int16_t * GGML_RESTRICT qs = y[i].bsums; + + for (int j = 0; j < QK_K/32; j += 2) { + __builtin_prefetch(q1, 0, 1); + __builtin_prefetch(qh, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; + vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; + vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; + vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; + q1 += 8; + + vector signed char q1x0 = (vector signed char)aux64x2_0; + vector signed char q1x1 = (vector signed char)aux64x2_1; + vector signed char q1x2 = (vector signed char)aux64x2_2; + vector signed char q1x3 = (vector signed char)aux64x2_3; + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); + + const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); + const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); + + vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); + vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); + vector signed short vscales = vec_sld(vscales23, vscales01, 8); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + + vector signed short q8ysums = vec_xl_len(qs, 8); + qs += 4; + q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); + + vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); + qh += 2; + vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); + + vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); + + vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + + vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; + } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector signed int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); + vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); + vector float vd = vec_mul(vxd, vyd); + + vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); + vector signed char q4x0 = vec_and(qxs, lowMask); + vector signed char q4x1 = vec_sr(qxs, v4); + + q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); + q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); + + vector signed char q8y0 = vec_xl( 0, y[ib].qs); + vector signed char q8y1 = vec_xl(16, y[ib].qs); + + vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + + vsumi0 = vec_sum4s(qv0, vsumi0); + vsumi1 = vec_sum4s(qv1, vsumi1); + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + } + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + sumf = vec_extract(vsumf0, 0); + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__POWER9_VECTOR__) + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector int v0 = vec_splats((int32_t)0); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + + vector float vsumf0 = vec_splats(0.0f); + vector float vsumf1 = vec_splats(0.0f); + vector float vsumf2 = vec_splats(0.0f); + vector float vsumf3 = vec_splats(0.0f); + + const vector signed char values = vec_xl( 0, kvalues_iq4nl); + + for (int ibl = 0; ibl < nb; ++ibl) { + + vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d)); + vector float vyd = vec_splats(y[ibl].d); + vector float vd = vec_mul(vxd, vyd); + + vector signed int vsumi0 = v0; + vector signed int vsumi1 = v0; + vector signed int vsumi2 = v0; + vector signed int vsumi3 = v0; + + uint16_t h = x[ibl].scales_h; + + const uint8_t * GGML_RESTRICT q4 = x[ibl].qs; + const uint8_t * GGML_RESTRICT sc = x[ibl].scales_l; + const int8_t * GGML_RESTRICT q8 = y[ibl].qs; + + for (int ib = 0; ib < QK_K/64; ib ++ ) { + __builtin_prefetch(q4, 0, 1); + __builtin_prefetch(q8, 0, 1); + + vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); + vector signed char qxs1 = (vector signed char)vec_xl(16, q4); + q4 += 32; + + vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); + vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); + vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); + vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); + + q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); + q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); + q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); + q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); + + vector signed char q8y0 = vec_xl( 0, q8); + vector signed char q8y1 = vec_xl(16, q8); + vector signed char q8y2 = vec_xl(32, q8); + vector signed char q8y3 = vec_xl(48, q8); + q8 += 64; + + vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); + vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); + vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); + vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); + + const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); + const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); + h >>= 4; + sc ++; + + vector signed short vscales01 = vec_splats((int16_t)ls0); + vector signed short vscales23 = vec_splats((int16_t)ls1); + + vsumi0 = vec_msum(qv0, vscales01, vsumi0); + vsumi1 = vec_msum(qv1, vscales01, vsumi1); + vsumi2 = vec_msum(qv2, vscales23, vsumi2); + vsumi3 = vec_msum(qv3, vscales23, vsumi3); + } + + vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); + vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); + vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); + vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); + } + + vsumf0 = vec_add(vsumf0, vsumf2); + vsumf1 = vec_add(vsumf1, vsumf3); + + vsumf0 = vec_add(vsumf0, vsumf1); + + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); + vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); + + *s = vec_extract(vsumf0, 0); + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c new file mode 100644 index 000000000..6f3aa94fb --- /dev/null +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -0,0 +1,2068 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__riscv_v) + + size_t vl = QK8_0; + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl); + + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); + + // convert to integer + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); + + // store result + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__riscv_v) + + size_t vl = QK8_1; + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl); + + vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); + + // convert to integer + vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); + vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); + + // store result + __riscv_vse8_v_i8m2(y[i].qs , vs, vl); + + // compute sum for y[i].s + vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl); + + // set y[i].s + int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); + y[i].s = GGML_FP32_TO_FP16(sum*d); + } + +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + +//===================================== Dot products ================================= + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__riscv_v) + size_t vl = qk / 2; + + for (; ib < nb; ++ib) { + // load elements + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + // subtract offset + vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); + vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__riscv_v) + size_t vl = qk / 2; + + for (; ib < nb; ++ib) { + // load elements + vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); + + vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); + vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); + vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); + + vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); + vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + + vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); + vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__riscv_v) + size_t vl; + size_t vlenb = __riscv_vlenb(); + + for (; ib < nb; ++ib) { + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } + + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + qh = __riscv_vmnand_mm_b4(qh, qh, vl); + vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); + + sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__riscv_v) + size_t vl; + size_t vlenb = __riscv_vlenb(); + + for (; ib < nb; ++ib) { + vl = qk / 2; + vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); + vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); + vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); + vint8m2_t v0c; + if (vlenb == 16) { + v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); + } else { + v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); + v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); + } + + vl = qk; + vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); + vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); + vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); + vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); + int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); + + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__riscv_v) + size_t vl = qk; + + for (; ib < nb; ++ib) { + // load elements + vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl); + vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl); + + vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __riscv_xtheadvector + + float sumf = 0; + uint8_t atmp[16]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp; + __asm__ __volatile__( + "th.vsetvli zero, %[vl16], e8, m1\n\t" + "th.vmv.v.x v8, zero\n\t" + "th.vlb.v v1, (%[sc])\n\t" + "th.vand.vi v0, v1, 0xF\n\t" + "th.vsrl.vi v1, v1, 4\n\t" + "th.vsb.v v0, (%[scale])\n\t" + "th.vwaddu.vx v16, v1, zero\n\t" + "th.vsetvli zero, %[vl16], e16, m2\n\t" + "th.vlh.v v2, (%[bsums])\n\t" + "th.vwmul.vv v4, v16, v2\n\t" + "th.vsetvli zero, %[vl16], e32, m4\n\t" + "th.vredsum.vs v8, v4, v8\n\t" + "th.vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + , [vl16] "r" (16) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "th.vsetvli zero, %[vl32], e8, m2\n\t" + "th.vlb.v v0, (%[q2])\n\t" + "th.vsrl.vi v2, v0, 2\n\t" + "th.vsrl.vi v4, v0, 4\n\t" + "th.vsrl.vi v6, v0, 6\n\t" + "th.vand.vi v0, v0, 0x3\n\t" + "th.vand.vi v2, v2, 0x3\n\t" + "th.vand.vi v4, v4, 0x3\n\t" + "th.vsetvli zero, %[vl128], e8, m8\n\t" + "th.vlb.v v8, (%[q8])\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" + "th.vwmul.vv v16, v0, v8\n\t" + "th.vwmul.vv v24, v4, v12\n\t" + "th.vsetvli zero, %[vl16], e16, m2\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vwredsum.vs v10, v16, v0\n\t" + "th.vwredsum.vs v9, v18, v0\n\t" + "th.vwredsum.vs v8, v20, v0\n\t" + "th.vwredsum.vs v7, v22, v0\n\t" + "th.vwredsum.vs v11, v24, v0\n\t" + "th.vwredsum.vs v12, v26, v0\n\t" + "th.vwredsum.vs v13, v28, v0\n\t" + "th.vwredsum.vs v14, v30, v0\n\t" + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vslideup.vi v10, v9, 1\n\t" + "th.vslideup.vi v8, v7, 1\n\t" + "th.vslideup.vi v11, v12, 1\n\t" + "th.vslideup.vi v13, v14, 1\n\t" + "th.vslideup.vi v10, v8, 2\n\t" + "th.vslideup.vi v11, v13, 2\n\t" + "li %[tmp], 8\n\t" + "th.vsetvli zero, %[tmp], e32, m2\n\t" + "th.vlbu.v v12, (%[scale])\n\t" + "th.vmul.vv v10, v10, v12\n\t" + "th.vredsum.vs v0, v10, v0\n\t" + "th.vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + , [vl16] "r" (16), [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + + sumf += dall * isum; + } + + *s = sumf; + +#elif defined __riscv_v + + float sumf = 0; + uint8_t atmp[16]; + + const int vector_length = __riscv_vlenb() * 8; + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; + + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + size_t vl = 16; + + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + + vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is = 0; + int isum = 0; + + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2 += 32; + q8 += 128; + is = 8; + } + + sumf += dall * isum; + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "vle8.v v1, (%[sc])\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vle16.v v2, (%[bsums])\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v0, (%[q2])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v15, (%[scale])\n\t" + "vzext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + + sumf += dall * isum; + } + break; + default: + assert(false && "Unsupported vector length"); + break; + } + + *s = sumf; + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __riscv_xtheadvector + + uint32_t utmp[4]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp; + __asm__ __volatile__( + "li %[tmp], 12\n\t" + "th.vsetvli zero, %[tmp], e8, m1\n\t" + "th.vlb.v v0, (%[s6b])\n\t" + "th.vmv.v.v v2, v0\n\t" + "li %[tmp], 2\n\t" + "th.vsetvli zero, %[tmp], e64, m1\n\t" + "th.vmv.v.x v9, %[sh]\n\t"\ + "th.vslidedown.vi v1, v0, 1\n\t" + "th.vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "th.vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vid.v v9\n\t" + "th.vmv.x.s %[tmp], v1\n\t" + "th.vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "th.vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "th.vsrl.vv v4, v1, v9\n\t" + "th.vsrl.vv v2, v0, v8\n\t" + "th.vand.vx v5, v4, %[kmask1]\n\t" + "th.vand.vx v3, v2, %[kmask2]\n\t" + "th.vsll.vi v6, v5, 4\n\t" + "th.vor.vv v7, v6, v3\n\t" + "li %[tmp], 16\n\t" + "th.vsetvli zero, %[tmp], e8, m1\n\t" + "th.vsub.vx v0, v7, %[c]\n\t" + "th.vsb.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + // fixme: use v0p7 mask layout directly + "th.vsetvli zero, %[vl32], e8, m2\n\t" + "th.vlb.v v8, (%[q3])\n\t" + "th.vsrl.vi v10, v8, 2\n\t" + "th.vsrl.vi v12, v8, 4\n\t" + "th.vsrl.vi v14, v8, 6\n\t" + "th.vand.vi v8, v8, 3\n\t" + "th.vand.vi v10, v10, 3\n\t" + "th.vand.vi v12, v12, 3\n\t" + "th.vlb.v v2, (%[qh])\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v8, v8, -4, v0.t\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v10, v10, -4, v0.t\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v12, v12, -4, v0.t\n\t" + "th.vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "th.vmseq.vx v0, v4, zero\n\t" + "th.vadd.vi v14, v14, -4, v0.t\n\t" + "th.vsetvli zero, %[vl128], e8, m8\n\t" + "th.vlb.v v0, (%[q8])\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" + "th.vwmul.vv v16, v0, v8\n\t" + "th.vwmul.vv v24, v4, v12\n\t" + "li %[tmp], 16\n\t" + "th.vsetvli zero, %[tmp], e16, m2\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vwredsum.vs v10, v16, v0\n\t" + "th.vwredsum.vs v9, v18, v0\n\t" + "th.vwredsum.vs v8, v20, v0\n\t" + "th.vwredsum.vs v7, v22, v0\n\t" + "th.vwredsum.vs v11, v24, v0\n\t" + "th.vwredsum.vs v12, v26, v0\n\t" + "th.vwredsum.vs v13, v28, v0\n\t" + "th.vwredsum.vs v14, v30, v0\n\t" + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vslideup.vi v10, v9, 1\n\t" + "th.vslideup.vi v8, v7, 1\n\t" + "th.vslideup.vi v11, v12, 1\n\t" + "th.vslideup.vi v13, v14, 1\n\t" + "th.vslideup.vi v10, v8, 2\n\t" + "th.vslideup.vi v11, v13, 2\n\t" + "li %[tmp], 8\n\t" + "th.vsetvli zero, %[tmp], e32, m2\n\t" + "th.vlb.v v12, (%[scale])\n\t" + "th.vmul.vv v10, v10, v12\n\t" + "th.vredsum.vs v0, v10, v0\n\t" + "th.vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } + + *s = sumf; + +#elif defined __riscv_v + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + const int vector_length = __riscv_vlenb() * 8; + + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v15, (%[scale])\n\t" + "vsext.vf4 v12, v15\n\t" + "vmul.vv v10, v10, v12\n\t" + "vredsum.vs v0, v10, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[isum], %[isum], %[tmp]" + : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } + break; + default: + assert(false && "Unsupported vector length"); + break; + } + + *s = sumf; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __riscv_xtheadvector + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int tmp, tmp2, sumi; + __asm__ __volatile__( + "li %[t1], 12\n\t" + "th.vsetvli zero, %[t1], e8, m1\n\t" + "th.vlb.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} + "li %[t1], 4\n\t" + "th.vsetvli zero, %[t1], e32, m1\n\t" + "th.vslidedown.vi v2, v1, 2\n\t" + "th.vmv.v.v v3, v2\n\t" + "th.vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "li %[t1], 2\n\t" + "th.vsetvli zero, %[t1], e32, m1\n\t" + "th.vmv.v.i v4, 4\n\t" + "th.vand.vx v8, v1, %[kmask1]\n\t" + "th.vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "th.vsrl.vi v6, v1, 6\n\t" + "th.vsrl.vv v7, v2, v5\n\t" + "th.vand.vx v0, v6, %[kmask3]\n\t" + "th.vand.vx v2, v7, %[kmask2]\n\t" + "th.vsll.vi v6, v0, 4\n\t" + "li %[t2], 8\n\t" + "addi %[t1], %[utmp], 4\n\t" + "th.vor.vv v1, v6, v2\n\t" + "th.vssw.v v8, (%[utmp]), %[t2]\n\t" + "th.vssw.v v1, (%[t1]), %[t2]\n\t" + "th.vsetvli zero, zero, e32, m2\n\t" // vl == 8 + "th.vlw.v v2, (%[bsums])\n\t" + "th.vsetvli zero, %[t2], e16, m1\n\t" + "th.vnsrl.vi v0, v2, 0\n\t" + "th.vnsrl.vi v1, v2, 16\n\t" + "th.vadd.vv v2, v0, v1\n\t" + "th.vlbu.v v4, (%[mins])\n\t" + "th.vwmul.vv v6, v4, v2\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vsetvli zero, %[t2], e32, m2\n\t" + "th.vredsum.vs v0, v6, v0\n\t" + "th.vmv.x.s %[sumi], v0" + : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) + : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf -= dmin * sumi; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + sumi = 0; + const uint8_t * scale = scales; + + for (int j = 0; j < QK_K/128; ++j) { + int vl128 = 128, vl64 = 64, vl32 = 32; + __asm__ __volatile__( + "th.vsetvli zero, %[vl128], e8, m8\n\t" + "th.vlb.v v8, (%[q8])\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" + "th.vlb.v v0, (%[q4])\n\t" + "th.vsrl.vi v4, v0, 4\n\t" + "th.vand.vi v0, v0, 0xF\n\t" + "th.vsetvli zero, %[vl32], e8, m2\n\t" + "th.vwmul.vv v28, v6, v14\n\t" + "th.vwmul.vv v20, v4, v10\n\t" + "th.vwmul.vv v24, v2, v12\n\t" + "th.vwmul.vv v16, v0, v8\n\t" + "li %[tmp], 4\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vlbu.v v1, (%[scale])\n\t" + "th.vmv.v.x v0, zero\n\t" + "th.vsetvli zero, %[vl32], e16, m4\n\t" + "th.vwredsum.vs v6, v24, v0\n\t" + "th.vwredsum.vs v7, v28, v0\n\t" + "th.vwredsum.vs v4, v16, v0\n\t" + "th.vwredsum.vs v5, v20, v0\n\t" + "th.vsetvli zero, %[tmp], e32, m1\n\t" + "th.vslideup.vi v6, v7, 1\n\t" + "th.vslideup.vi v4, v5, 1\n\t" + "th.vslideup.vi v4, v6, 2\n\t" + "th.vmul.vv v8, v4, v1\n\t" + "th.vredsum.vs v0, v8, v0\n\t" + "th.vmv.x.s %[tmp], v0\n\t" + "add %[sumi], %[sumi], %[tmp]" + : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) + : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) + , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + q4 += 64; q8 += 128; scale += 4; + } + + sumf += d * sumi; + + } + + *s = sumf; + +#elif defined __riscv_v + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + const int vector_length = __riscv_vlenb() * 8; + + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int tmp, tmp2, sumi; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} + "vsetivli zero, 4, e32, m1\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "li %[t2], 8\n\t" + "addi %[t1], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v8, (%[utmp]), %[t2]\n\t" + "vsse32.v v1, (%[t1]), %[t2]\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vredsum.vs v0, v6, v0\n\t" + "vmv.x.s %[sumi], v0" + : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) + : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf -= dmin * sumi; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + sumi = 0; + const uint8_t * scale = scales; + + for (int j = 0; j < QK_K/128; ++j) { + int vl128 = 128, vl64 = 64, vl32 = 32; + __asm__ __volatile__( + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v8, (%[q8])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v0, (%[q4])\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vle8.v v2, (%[scale])\n\t" + "vmv.v.x v0, zero\n\t" + "vzext.vf4 v1, v2\n\t" + "vsetvli zero, %[vl32], e16, m4\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v6, v7, 1\n\t" + "vslideup.vi v4, v5, 1\n\t" + "vslideup.vi v4, v6, 2\n\t" + "vmul.vv v8, v4, v1\n\t" + "vredsum.vs v0, v8, v0\n\t" + "vmv.x.s %[tmp], v0\n\t" + "add %[sumi], %[sumi], %[tmp]" + : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) + : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) + , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + q4 += 64; q8 += 128; scale += 4; + } + + sumf += d * sumi; + } + break; + default: + assert(false && "Unsupported vector length"); + break; + } + + *s = sumf; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __riscv_v + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + + vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl); + vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl); + vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl); + vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl); + vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl); + vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl); + + // compute mask for addition + vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl)); + vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl); + vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl); + m <<= 1; + + vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl)); + vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl); + vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl); + vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl); + m <<= 1; + + vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl); + vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl); + + vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl); + vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + sums += aux32 * d; + + } + + *s = sumf+sums; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __riscv_xtheadvector + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int sum_t = 0; + int t0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "th.vsetvli zero, %[vl32], e8, m2\n\t" // vl == 32 + "th.vlb.v v4, (%[qh])\n\t" + "th.vsll.vi v0, v4, 4\n\t" + "th.vsll.vi v2, v4, 2\n\t" + "th.vsrl.vi v6, v4, 2\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64 + "th.vlb.v v8, (%[q6])\n\t" + "th.vsrl.vi v12, v8, 4\n\t" + "th.vand.vi v8, v8, 0xF\n\t" + "th.vsetvli zero, %[vl128], e8, m8\n\t" // vl == 128 + "th.vand.vx v0, v0, %[mask]\n\t" + "th.vor.vv v8, v8, v0\n\t" + "th.vlb.v v0, (%[q8])\n\t" + "th.vsub.vx v8, v8, %[vl32]\n\t" + "th.vsetvli zero, %[vl64], e8, m4\n\t" // vl == 64 + "th.vwmul.vv v16, v0, v8\n\t" + "th.vwmul.vv v24, v4, v12\n\t" + "li %[t0], 16\n\t" + "th.vsetvli zero, %[t0], e16, m2\n\t" // vl == 16 + "th.vmv.v.x v0, zero\n\t" + "th.vwredsum.vs v10, v16, v0\n\t" + "th.vwredsum.vs v9, v18, v0\n\t" + "th.vwredsum.vs v8, v20, v0\n\t" + "th.vwredsum.vs v7, v22, v0\n\t" + "th.vwredsum.vs v11, v24, v0\n\t" + "th.vwredsum.vs v12, v26, v0\n\t" + "th.vwredsum.vs v13, v28, v0\n\t" + "th.vwredsum.vs v14, v30, v0\n\t" + "li %[t0], 4\n\t" + "th.vsetvli zero, %[t0], e32, m1\n\t" // vl == 4 + "th.vslideup.vi v10, v9, 1\n\t" + "th.vslideup.vi v8, v7, 1\n\t" + "th.vslideup.vi v11, v12, 1\n\t" + "th.vslideup.vi v13, v14, 1\n\t" + "th.vslideup.vi v10, v8, 2\n\t" + "th.vslideup.vi v11, v13, 2\n\t" + "li %[t0], 8\n\t" + "th.vsetvli zero, %[t0], e32, m2\n\t" // vl == 8 + "th.vlb.v v4, (%[scale])\n\t" + "th.vmul.vv v2, v4, v10\n\t" + "th.vredsum.vs v0, v2, v0\n\t" + "th.vmv.x.s %[t0], v0\n\t" + "add %[sumi], %[sumi], %[t0]" + : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) + : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q6 += 64; qh += 32; q8 += 128; scale += 8; + } + + sumf += d * sum_t; + + } + + *s = sumf; + +#elif defined __riscv_v + + float sumf = 0; + const int vector_length = __riscv_vlenb() * 8; + + switch (vector_length) { + case 256: + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + break; + case 128: + for (int i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int sum_t = 0; + int t0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vle8.v v8, (%[q6])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vslideup.vi v10, v9, 1\n\t" + "vslideup.vi v8, v7, 1\n\t" + "vslideup.vi v11, v12, 1\n\t" + "vslideup.vi v13, v14, 1\n\t" + "vslideup.vi v10, v8, 2\n\t" + "vslideup.vi v11, v13, 2\n\t" + "vsetivli zero, 8, e32, m2\n\t" + "vle8.v v2, (%[scale])\n\t" + "vsext.vf4 v4, v2\n\t" + "vmul.vv v2, v4, v10\n\t" + "vredsum.vs v0, v2, v0\n\t" + "vmv.x.s %[t0], v0\n\t" + "add %[sumi], %[sumi], %[t0]" + : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) + : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q6 += 64; qh += 32; q8 += 128; scale += 8; + } + + sumf += d * sum_t; + + } + break; + default: + assert(false && "Unsupported vector length"); + break; + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp new file mode 100644 index 000000000..0882b4102 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -0,0 +1,396 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" +#include "ggml-backend-impl.h" + +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" +#include "traits.h" + +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GGML_CPU_CLANG_WORKAROUND +#include "../../repack.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#endif + +#define UNUSED GGML_UNUSED + +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment constraints + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4)); + + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + // vector version needs Zvfhmin extension + const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d); + const float b_scales[8] = { + GGML_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); + sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4); + } + __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4); + } + return; + } + +#endif + { + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v + if (__riscv_vlenb() >= QK4_0) { + const size_t vl = QK4_0; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); + for (int l = 0; l < nb; l++) { + const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); + const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); + const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); + const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); + const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); + const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); + const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); + + // vector version needs Zvfhmin extension + const float a_scales[4] = { + GGML_FP16_TO_FP32(a_ptr[l].d[0]), + GGML_FP16_TO_FP32(a_ptr[l].d[1]), + GGML_FP16_TO_FP32(a_ptr[l].d[2]), + GGML_FP16_TO_FP32(a_ptr[l].d[3]) + }; + const float b_scales[8] = { + GGML_FP16_TO_FP32(b_ptr[l].d[0]), + GGML_FP16_TO_FP32(b_ptr[l].d[1]), + GGML_FP16_TO_FP32(b_ptr[l].d[2]), + GGML_FP16_TO_FP32(b_ptr[l].d[3]), + GGML_FP16_TO_FP32(b_ptr[l].d[4]), + GGML_FP16_TO_FP32(b_ptr[l].d[5]), + GGML_FP16_TO_FP32(b_ptr[l].d[6]), + GGML_FP16_TO_FP32(b_ptr[l].d[7]) + }; + const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); + + const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; + const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; + const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; + const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l0; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l0 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); + sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; + const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; + const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; + const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l1; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l1 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); + sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; + const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; + const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; + const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l2; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l2 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); + sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); + } + + const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; + const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; + const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; + const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; + __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment + vint16m4_t sumi_l3; + { + const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); + const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); + const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); + const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); + const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); + const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); + const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); + const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); + + sumi_l3 = sumi_hi_m; + } + + { + const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); + const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); + const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); + const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); + const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); + const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); + const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); + const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); + const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); + const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); + const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); + const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); + const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); + + const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); + sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); + } + } + __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); + __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); + } + } + + return; + } + +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c new file mode 100644 index 000000000..26bd90875 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -0,0 +1,1299 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__VXE__) || defined(__VXE2__) + for (int i = 0; i < nb; i++) { + __vector float srcv [8]; + __vector float asrcv[8]; + __vector float amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f / d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const __vector float v = vec_mul(srcv[j], vec_splats(id)); + const __vector int32_t vi = vec_signed(v); + + y[i].qs[4*j + 0] = vec_extract(vi, 0); + y[i].qs[4*j + 1] = vec_extract(vi, 1); + y[i].qs[4*j + 2] = vec_extract(vi, 2); + y[i].qs[4*j + 3] = vec_extract(vi, 3); + } + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__VXE__) || defined(__VXE2__) + for (int i = 0; i < nb; i++) { + __vector float srcv [8]; + __vector float asrcv[8]; + __vector float amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); + for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(vec_extract(amaxv[0], 0), + vec_extract(amaxv[0], 1)), + MAX(vec_extract(amaxv[0], 2), + vec_extract(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f / d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + __vector int32_t acc = vec_splats(0); + + for (int j = 0; j < 8; j++) { + const __vector float v = vec_mul(srcv[j], vec_splats(id)); + const __vector int32_t vi = vec_signed(v); + + y[i].qs[4*j + 0] = vec_extract(vi, 0); + y[i].qs[4*j + 1] = vec_extract(vi, 1); + y[i].qs[4*j + 2] = vec_extract(vi, 2); + y[i].qs[4*j + 3] = vec_extract(vi, 3); + + acc = vec_add(acc, vi); + } + + y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3])); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + + +//===================================== Dot products ================================= + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__VXE__) || defined(__VXE2__) + __vector float acc = vec_splats(0.0f); + + const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F); + const __vector int8_t v_s = vec_splats( (const int8_t)0x08); + + for (; ib < nb; ++ib) { + const __vector uint8_t v_x = vec_xl(0, x[ib].qs); + const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m); + const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4); + + const __vector int8_t v_xls = vec_sub(v_xl, v_s); + const __vector int8_t v_xhs = vec_sub(v_xh, v_s); + + const __vector int8_t v_yl = vec_xl(0 , y[ib].qs); + const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs); + + const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl); + const __vector int16_t v_xylse = vec_mule(v_xls, v_yl); + const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh); + const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh); + + __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); + + const __vector float v_xy = vec_float(vec_unpackh(v_xy_)); + const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + acc = vec_madd(v_xy, v_d, acc); + } + + sumf = acc[0] + acc[1] + acc[2] + acc[3]; + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__VXE__) || defined(__VXE2__) + float summs = 0; + float32x4_t acc = vec_splats(0.0f); + + const uint8x16_t v_m = vec_splat_u8(0x0F); + +#pragma GCC unroll 4 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + const uint8x16_t v_x = vec_xl(0, x[ib].qs); + const int8x16_t v_xl = (const int8x16_t)(v_x & v_m); + const int8x16_t v_xh = (const int8x16_t)(v_x >> 4); + + const int8x16_t v_yl = vec_xl(0 , y[ib].qs); + const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs); + + const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); + const float32x4_t v_xy = vec_float(v_xy_); + + const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + acc = vec_madd(v_xy, v_d, acc); + } + + sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs; + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__VXE__) || defined(__VXE2__) + __vector float acc = vec_splats(0.0f); + +#pragma GCC unroll 8 + for (; ib < nb; ++ib) { + __builtin_prefetch(x[ib].qs, 0, 1); + __builtin_prefetch(y[ib].qs, 0, 1); + + const int8x16_t v_xl = vec_xl(0 , x[ib].qs); + const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs); + const int8x16_t v_yl = vec_xl(0 , y[ib].qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); + + const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); + const float32x4_t v_xy = vec_float(v_xy_); + const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + acc = vec_madd(v_xy, v_d, acc); + } + + sumf = acc[0] + acc[1] + acc[2] + acc[3]; + +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__VXE__) || defined(__VXE2__) + uint32_t aux[3]; + uint32_t utmp[4]; + + const int32x4_t v_z = vec_splat_s32(0); + const uint8x16_t v_3m = vec_splat_u8(0x03); + + const uint8x16_t v_0c = vec_splat_u8(1); + const uint8x16_t v_1c = vec_sl(v_0c, 1); + const uint8x16_t v_2c = vec_sl(v_0c, 2); + const uint8x16_t v_3c = vec_sl(v_0c, 3); + + uint8x16_t q3h[4]; + uint8x16_t q3b[2]; + int8x16_t q3bytes[4]; + int8x16_t q8bytes[4]; + uint8x16_t qhbits[2]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict x0l = x[i].qs; + const uint8_t * restrict x0h = x[i].hmask; + const int8_t * restrict y0 = y[i].qs; + + qhbits[0] = vec_xl(0 , x0h); + qhbits[1] = vec_xl(16, x0h); + + int32_t isum = 0; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + for (int j = 0; j < QK_K/128; ++j) { + int32x4_t isum0, isum1, isum2, isum3; + + q3b[0] = vec_xl(0 , x0l); + q3b[1] = vec_xl(16, x0l); + x0l += 32; + + q8bytes[0] = vec_xl(0 , y0); + q8bytes[1] = vec_xl(16 , y0); + q8bytes[2] = vec_xl(32 , y0); + q8bytes[3] = vec_xl(48 , y0); + q8bytes[4] = vec_xl(64 , y0); + q8bytes[5] = vec_xl(80 , y0); + q8bytes[6] = vec_xl(96 , y0); + q8bytes[7] = vec_xl(112, y0); + y0 += 128; + + q3h[0] = vec_sl(vec_andc(v_0c, qhbits[0]), 2); + q3h[1] = vec_sl(vec_andc(v_0c, qhbits[1]), 2); + q3h[2] = vec_sl(vec_andc(v_1c, qhbits[0]), 1); + q3h[3] = vec_sl(vec_andc(v_1c, qhbits[1]), 1); + + q3bytes[0] = vec_sub((int8x16_t)vec_and(q3b[0], v_3m), (int8x16_t)q3h[0]); + q3bytes[1] = vec_sub((int8x16_t)vec_and(q3b[1], v_3m), (int8x16_t)q3h[1]); + q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 2), v_3m), (int8x16_t)q3h[2]); + q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 2), v_3m), (int8x16_t)q3h[3]); + + isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[0]); + isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[1]); + isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[2]); + isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[3]); + + isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; + isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; + isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; + isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; + + scale += 4; + + q3h[0] = vec_andc(v_2c, qhbits[0]); + q3h[1] = vec_andc(v_2c, qhbits[1]); + q3h[2] = vec_sr(vec_andc(v_3c, qhbits[0]), 1); + q3h[3] = vec_sr(vec_andc(v_3c, qhbits[1]), 1); + + q3bytes[0] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 4), v_3m), (int8x16_t)q3h[0]); + q3bytes[1] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 4), v_3m), (int8x16_t)q3h[1]); + q3bytes[2] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[0], 6), v_3m), (int8x16_t)q3h[2]); + q3bytes[3] = vec_sub((int8x16_t)vec_and(vec_sr(q3b[1], 6), v_3m), (int8x16_t)q3h[3]); + + isum0 = ggml_vec_dot(v_z, q3bytes[0], q8bytes[4]); + isum1 = ggml_vec_dot(v_z, q3bytes[1], q8bytes[5]); + isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]); + isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]); + + isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; + isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; + isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; + isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; + + scale += 4; + + if (j == 0) { + qhbits[0] = vec_sr(qhbits[0], 4); + qhbits[1] = vec_sr(qhbits[1], 4); + } + } + + sum += d * isum; + } + + *s = sum; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined(__VXE__) || defined(__VXE2__) + const uint8x16_t v_lm = vec_splat_u8(0x0F); + const int32x4_t v_z = vec_splat_s32(0); + + uint8x16_t v_x[2]; + int8x16_t v_xl[2]; + int8x16_t v_y[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); + const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); + const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh); + + memcpy(utmp, x[i].scales, 12); + + uint32x4_t v_mins8 = { 0 }; + v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0); + v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8); + + const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh); + const int32x4_t v_minse = vec_mule(v_ysums, v_minsh); + const int32x4_t v_mins = v_minso + v_minse; + sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]); + + const uint8_t * scales = (const uint8_t *)utmp; + const uint8_t * GGML_RESTRICT x0 = x[i].qs; + const int8_t * GGML_RESTRICT y0 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + v_x[0] = vec_xl(0 , x0); + v_x[1] = vec_xl(16, x0); + x0 += 32; + + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + y0 += 32; + + v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm); + v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm); + + const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); + sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0]; + + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + y0 += 32; + + v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4); + v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4); + + const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); + sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined(__VXE__) || defined(__VXE2__) + const uint8x16_t v_lm = vec_splat_u8(0x0F); + const uint8x16_t v_1m = vec_splat_u8(0x01); + const uint8x16_t v_2m = vec_splat_u8(0x02); + + const int32x4_t v_z = vec_splat_s32(0); + + const uchar8x16_t v_minsm = { + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF + }; + + int8x16_t q5b[4]; + uint8x16_t q5h[4]; + + uint8x16_t v_xl[2]; + uint8x16_t v_xh[2]; + int8x16_t v_y[4]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); + const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); + const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp); + const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm); + const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8); + + const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); + const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); + const int32x4_t v_mins = vec_add(v_minsho, v_minshe); + const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + + const uint8_t * scales = (const uint8_t *)utmp; + const uint8_t * GGML_RESTRICT x0l = x[i].qs; + const uint8_t * GGML_RESTRICT x0h = x[i].qh; + const int8_t * GGML_RESTRICT y0 = y[i].qs; + + v_xh[0] = vec_xl(0 , x0h); + v_xh[1] = vec_xl(16, x0h); + + int32_t sumi = 0; + for (int j = 0; j < QK_K/64; ++j) { + v_xl[0] = vec_xl(0 , x0l); + v_xl[1] = vec_xl(16, x0l); + x0l += 32; + + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + v_y[2] = vec_xl(32, y0); + v_y[3] = vec_xl(48, y0); + y0 += 64; + + q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4); + q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4); + q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3); + q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3); + v_xh[0] = vec_sr(v_xh[0], 2); + v_xh[1] = vec_sr(v_xh[1], 2); + + q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]); + q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]); + q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]); + q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]); + + int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]); + int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]); + + sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++; + sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++; + } + + sumf += d * sumi - dmin * mins; + } + + *s = sumf; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__VXE__) || defined(__VXE2__) + float sum = 0; + + // Lower 4-bit and upper 2-bit masks + const uint8x16_t v_lm = vec_splat_u8(0x0F); + const uint8x16_t v_um = vec_splat_u8(0x03); + + const int32x4_t v_z = vec_splat_s32(0); + + int8x16_t q6b[4]; + uint8x16_t q6h[4]; + + uint8x16_t v_xl[4]; + uint8x16_t v_xh[2]; + int8x16_t v_y[4]; + + for (int i = 0; i < nb; ++i) { + const float d_all = GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT x0l = x[i].ql; + const uint8_t * GGML_RESTRICT x0h = x[i].qh; + const int8_t * GGML_RESTRICT y0 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); + const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); + + const int8x16_t v_scale = vec_xl(0, scale); + const int16x8_t v_scalel = vec_unpackh(v_scale); + const int16x8_t v_scaleh = vec_unpackl(v_scale); + + const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel); + const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel); + const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh); + const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); + const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; + + const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; + + int32_t isum = 0; + for (int j = 0; j < QK_K/128; ++j) { + // Load model upper 2 bits + v_xh[0] = vec_xl(0 , x0h); + v_xh[1] = vec_xl(16, x0h); + x0h += 32; + + // Load model lower 4 bits + v_xl[0] = vec_xl(0 , x0l); + v_xl[1] = vec_xl(16, x0l); + v_xl[2] = vec_xl(32, x0l); + v_xl[3] = vec_xl(48, x0l); + x0l += 64; + + // Load activation quants + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + v_y[2] = vec_xl(32, y0); + v_y[3] = vec_xl(48, y0); + y0 += 64; + + q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4); + q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4); + uint8x16_t shifted = vec_sr(v_xh[0], 2); + q6h[2] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[1], 2); + q6h[3] = vec_sl(vec_and(v_um, shifted), 4); + + q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0])); + q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1])); + q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2])); + q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3])); + + int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]); + int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]); + int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); + int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); + + isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + + (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + + (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + + (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + + scale += 4; + + + // Load activation quants + v_y[0] = vec_xl(0 , y0); + v_y[1] = vec_xl(16, y0); + v_y[2] = vec_xl(32, y0); + v_y[3] = vec_xl(48, y0); + y0 += 64; + + shifted = vec_sr(v_xh[0], 4); + q6h[0] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[1], 4); + q6h[1] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[0], 6); + q6h[2] = vec_sl(vec_and(v_um, shifted), 4); + shifted = vec_sr(v_xh[1], 6); + q6h[3] = vec_sl(vec_and(v_um, shifted), 4); + + q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0])); + q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1])); + q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2])); + q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3])); + + summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]); + summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]); + summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); + summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); + + isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + + (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + + (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + + (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; + + scale += 4; + } + + sum += d_all * y[i].d * (isum - 32 * mins); + } + + *s = sum; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +// #if defined(__VXE__) || defined(__VXE2__) +// static const int8_t keven_signs_q2xs[1024] = { +// 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, +// 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, +// 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, +// 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, +// 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, +// 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, +// 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, +// 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, +// 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, +// 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, +// 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, +// 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, +// 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, +// 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, +// 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, +// 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, +// 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, +// 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, +// 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, +// 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, +// 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, +// 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, +// 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, +// 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, +// 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, +// 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, +// 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, +// 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, +// 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, +// 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, +// 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, +// 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +// }; +// #endif + +// void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +// assert(n % QK_K == 0); +// assert(nrc == 1); +// UNUSED(nrc); +// UNUSED(bx); +// UNUSED(by); +// UNUSED(bs); + +// const block_iq2_xxs * GGML_RESTRICT x = vx; +// const block_q8_K * GGML_RESTRICT y = vy; + +// const int nb = n / QK_K; + +// #if defined(__VXE__) || defined(__VXE2__) +// const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + +// uint32_t aux32[4]; +// const uint8_t * aux8 = (const uint8_t *)aux32; + +// float sumf = 0; + +// for (int i = 0; i < nb; ++i) { +// const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; +// const uint16_t * GGML_RESTRICT q2 = x[i].qs; +// const int8_t * GGML_RESTRICT q8 = y[i].qs; + +// float sumf1 = 0, sumf2 = 0; + +// for (int ib32 = 0; ib32 < QK_K/32; ib += 2) { +// int8x16_t q8b0 = vec_xl( 0, q8); +// int8x16_t qb81 = vec_xl(16, q8); +// int8x16_t q8b2 = vec_xl(32, q8); +// int8x16_t q8b3 = vec_xl(48, q8); +// q8 += 64; + +// memcpy(aux32, q2, 4 * sizeof(uint32_t)); +// q2 += 8; + +// int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) }; +// int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) }; +// int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) }; +// int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) }; + +// int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) }; +// int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) }; +// int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) }; +// int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) }; + +// q2u0 = vec_mul(q2u0, q2s0); +// q2u1 = vec_mul(q2u1, q2s1); +// q2u2 = vec_mul(q2u2, q2s2); +// q2u3 = vec_mul(q2u3, q2s3); + +// const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1); +// const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3); + +// sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28)); +// sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28)); +// } + +// sumf += d * (sumf1 + sumf2); +// } + +// *s = 0.25f * sumf; + +// #else + +// uint32_t aux32[2]; +// const uint8_t * aux8 = (const uint8_t *)aux32; + +// float sumf = 0.f; +// for (int i = 0; i < nb; ++i) { +// const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; +// const uint16_t * GGML_RESTRICT q2 = x[i].qs; +// const int8_t * GGML_RESTRICT q8 = y[i].qs; +// int32_t bsum = 0; +// for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { +// memcpy(aux32, q2, 2*sizeof(uint32_t)); +// q2 += 4; +// const uint32_t ls = 2*(aux32[1] >> 28) + 1; +// int32_t sumi = 0; +// for (int l = 0; l < 4; ++l) { +// const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); +// const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; +// for (int j = 0; j < 8; ++j) { +// sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); +// } +// q8 += 8; +// } +// bsum += sumi * ls; +// } +// sumf += d * bsum; +// } +// *s = 0.125f * sumf; +// #endif +// } + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined(__VXE__) || defined(__VXE2__) + const int8x16_t v_k = vec_xl(0, kvalues_iq4nl); + const uint8x16_t v_m = vec_splat_u8(0x0F); + + for (; ib < nb; ++ib) { + const block_iq4_nl * GGML_RESTRICT x0 = &x[ib]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; + + const uint8x16_t v_x = vec_xl(0, x0->qs); + int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m); + int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4); + + v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl); + v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh); + + const int8x16_t v_yl = vec_xl(0 , y0->qs); + const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); + const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); + + sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]); + } + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__VXE__) || defined(__VXE2__) + const int8x16_t v_k = vec_xl(0, kvalues_iq4nl); + const uint8x16_t v_m = vec_splat_u8(0x0F); + + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * GGML_RESTRICT q4 = x[ibl].qs; + const int8_t * GGML_RESTRICT q8 = y[ibl].qs; + + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/64; ++ib) { + const uint8x16_t v_x0 = vec_xl(0 , q4); + const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4); + q4 += 32; + + int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); + int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); + int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m); + int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4); + + v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l); + v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h); + v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l); + v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h); + + const int8x16_t v_y0 = vec_xl( 0, q8); + const int8x16_t v_y1 = vec_xl(16, q8); + const int8x16_t v_y2 = vec_xl(32, q8); + const int8x16_t v_y3 = vec_xl(48, q8); + q8 += 64; + + int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1); + int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3); + + int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + + h >>= 4; + + sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1; + sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2; + } + + sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c new file mode 100644 index 000000000..4ec97f533 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -0,0 +1,1480 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +#if defined(__wasm_simd128__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined __wasm_simd128__ + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + } + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; +#if defined __wasm_simd128__ + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + v128_t accv = wasm_i32x4_splat(0); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + + accv = wasm_i32x4_add(accv, vi); + } + + y[i].s = GGML_FP32_TO_FP16( + d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3))); + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { +#ifdef __wasm_simd128__ + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type + + for (int i = 0; i < nb; i++) { + const float * x_block = x + i * QK_K; + + v128_t min_vec = wasm_v128_load(x_block); + v128_t max_vec = min_vec; + + for (int j = 4; j < QK_K; j += 4) { + v128_t x_vec = wasm_v128_load(x_block + j); + max_vec = wasm_f32x4_pmax(max_vec, x_vec); + min_vec = wasm_f32x4_pmin(min_vec, x_vec); + } + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1)); + max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1)); + min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2)); + float max = wasm_f32x4_extract_lane(max_vec, 0); + float min = wasm_f32x4_extract_lane(min_vec, 0); + float amax = -min > max ? min : max; + + if (amax == 0.0f) { + yc[i].d = 0.0f; + const v128_t zero = wasm_i8x16_splat(0); + for (int j = 0; j < QK_K; j += 16) { + wasm_v128_store(yc[i].qs + j, zero); + } + continue; + } + + const float iscale = -127.0f / amax; + const v128_t scale_vec = wasm_f32x4_splat(iscale); + + // Process 16 elements per iteration + for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) { + // Load and quantize 16 floats + v128_t x0 = wasm_v128_load(x_block + j); + v128_t x1 = wasm_v128_load(x_block + j + 4); + v128_t x2 = wasm_v128_load(x_block + j + 8); + v128_t x3 = wasm_v128_load(x_block + j + 12); + + v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec)); + v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec)); + v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec)); + v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec)); + + // Convert to i32 with saturation + v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0); + v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1); + v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2); + v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3); + + // Pack into 16 i8 values + v128_t i8 = wasm_i8x16_narrow_i16x8( + wasm_i16x8_narrow_i32x4(i0, i1), + wasm_i16x8_narrow_i32x4(i2, i3) + ); + wasm_v128_store(yc[i].qs + j, i8); + + // Calculate bsums using SIMD + v128_t sum16 = wasm_i16x8_add( + wasm_i16x8_extend_low_i8x16(i8), + wasm_i16x8_extend_high_i8x16(i8) + ); + v128_t sum32 = wasm_i32x4_add( + wasm_i32x4_extend_low_i16x8(sum16), + wasm_i32x4_extend_high_i16x8(sum16) + ); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1)); + sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2)); + yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0); + } + + yc[i].d = 1.0f / iscale; + } +#else + quantize_row_q8_K_ref(x, y, k); +#endif +} + + +//===================================== Dot products ================================= + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + const v128_t s8b = wasm_i8x16_splat(0x8); + + for (; ib + 1 < nb; ib += 2) { + const block_q4_0 * GGML_RESTRICT x0 = &x[ib]; + const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; + const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; + + // Load and process x0 + v128_t v0_0 = wasm_v128_load(x0->qs); + v128_t v0_0l = wasm_v128_and(v0_0, m4b); + v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); + v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); + v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); + + // Load y0 vectors + v128_t y0_l = wasm_v128_load(y0->qs); + v128_t y0_h = wasm_v128_load(y0->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls); + v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls); + v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs); + v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs); + + v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l); + v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l); + v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h); + v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h); + + v128_t dp0 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0l, dy0ll), + wasm_i32x4_dot_i16x8(dx0h, dy0lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx0hl, dy0hl), + wasm_i32x4_dot_i16x8(dx0hh, dy0hh) + ) + ); + + // Load and process x1 + v128_t v0_1 = wasm_v128_load(x1->qs); + v128_t v0_1l = wasm_v128_and(v0_1, m4b); + v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); + v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); + v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); + + // Load y1 vectors + v128_t y1_l = wasm_v128_load(y1->qs); + v128_t y1_h = wasm_v128_load(y1->qs + 16); + + // Extend to i16x8 and compute dot products + v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls); + v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls); + v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs); + v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs); + + v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l); + v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l); + v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h); + v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h); + + v128_t dp1 = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1l, dy1ll), + wasm_i32x4_dot_i16x8(dx1h, dy1lh) + ), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(dx1hl, dy1hl), + wasm_i32x4_dot_i16x8(dx1hh, dy1hh) + ) + ); + + // Accumulate results with scaling + float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d); + + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0))); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + uint32_t qh_; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (; ib < nb; ++ib) { + const block_q5_0 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh_, x0->qh, sizeof(qh_)); + + tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh_ >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + float summs = 0.0f; + + uint32_t qh_; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (; ib < nb; ++ib) { + const block_q5_1 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; + + summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh_, x0->qh, sizeof(qh_)); + + tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh_ >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + + for (; ib < nb; ++ib) { + const block_q8_0 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; + + const v128_t x0_0 = wasm_v128_load(x0->qs); + const v128_t x0_1 = wasm_v128_load(x0->qs + 16); + const v128_t y0_0 = wasm_v128_load(y0->qs); + const v128_t y0_1 = wasm_v128_load(y0->qs + 16); + + // Extend 8-bit to 16-bit + const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0); + const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0); + const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1); + const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1); + + const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0); + const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0); + const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1); + const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1); + + // Compute dot products + const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l); + const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h); + const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l); + const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h); + + // Sum all dot products + const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1)); + + // Convert to float and accumulate + const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); + +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __wasm_simd128__ + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + // Vectorized summs calculation + v128_t summs_vec = wasm_i32x4_splat(0); + { + v128_t sc_vec = wasm_v128_load(sc); + v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4); + + v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper); + v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper); + + v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]); + v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]); + + summs_vec = wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1), + wasm_i32x4_dot_i16x8(sc_high, bsums2)), + summs_vec + ); + + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1)); + summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2)); + } + int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0); + + // Vectorized isum calculation + int32_t isum = 0; + const uint8_t * sc_ptr = sc; + const int k_iters = QK_K/128; + + for (int k = 0; k < k_iters; ++k) { + v128_t isum_vec = wasm_i32x4_splat(0); + int shift = 0; + + for (int j = 0; j < 4; ++j) { + const int d0 = (sc_ptr[0] & 0xF); + const int d1 = (sc_ptr[1] & 0xF); + sc_ptr += 2; + + // Process first 16 elements + v128_t q2_0 = wasm_v128_load(q2); + v128_t q8_0 = wasm_v128_load(q8); + v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift); + v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03)); + + // Process next 16 elements + v128_t q2_1 = wasm_v128_load(q2 + 16); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift); + v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03)); + + // Calculate dot products + v128_t p0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_0), + wasm_i16x8_extend_low_i8x16(q2_bits_0) + ); + v128_t p1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_0), + wasm_i16x8_extend_high_i8x16(q2_bits_0) + ); + v128_t p2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q8_1), + wasm_i16x8_extend_low_i8x16(q2_bits_1) + ); + v128_t p3 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q8_1), + wasm_i16x8_extend_high_i8x16(q2_bits_1) + ); + + // Accumulate scaled results + v128_t scaled = wasm_i32x4_add( + wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)), + wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1)) + ); + + isum_vec = wasm_i32x4_add(isum_vec, scaled); + q8 += 32; + shift += 2; + } + q2 += 32; + + // Horizontal sum of isum_vec + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1)); + isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2)); + isum += wasm_i32x4_extract_lane(isum_vec, 0); + } + + const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf += dall * isum - dmin * summs; + } + + *s = sumf; + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __wasm_simd128__ + int8_t aux8[QK_K]; + float sums[8] = {0}; + uint32_t auxs[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // Process blocks with SIMD + int8_t * a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int shift = 0; shift <= 6; shift += 2) { + v128_t v_m = wasm_i8x16_splat(m); + for (int l = 0; l < 32; l += 16) { + v128_t v_q3 = wasm_v128_load(q3 + l); + v128_t v_shift = wasm_i8x16_shr(v_q3, shift); + v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03)); + + v128_t v_hm = wasm_v128_load(hm + l); + v128_t v_mask = wasm_v128_and(v_hm, v_m); + v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0)); + + v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask))); + wasm_v128_store(a + l, v_low2); + } + a += 32; + m <<= 1; + } + q3 += 32; + } + + // Extract scales + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + const int8_t * scales = (const int8_t *)auxs; + + // SIMD dot product with register accumulators + v128_t v_acc0 = wasm_i32x4_splat(0); + v128_t v_acc1 = wasm_i32x4_splat(0); + a = aux8; + for (int j = 0; j < QK_K/16; ++j) { + const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32); + + // Process 16 elements per iteration + for (int k = 0; k < 2; ++k) { + const v128_t v_q8 = wasm_i16x8_load8x8(q8); + const v128_t v_a = wasm_i16x8_load8x8(a); + + v128_t v_prod = wasm_i16x8_mul(v_q8, v_a); + v_prod = wasm_i16x8_mul(v_prod, v_scale); + + v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod)); + v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod)); + + q8 += 8; + a += 8; + } + } + + // Accumulate results + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const v128_t v_d = wasm_f32x4_splat(d); + v128_t v_sum = wasm_f32x4_add( + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d), + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d) + ); + + // Accumulate into sums vector + wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum)); + } + + // Horizontal sum + v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4)); + sumf = wasm_f32x4_extract_lane(v_sum, 0) + + wasm_f32x4_extract_lane(v_sum, 1) + + wasm_f32x4_extract_lane(v_sum, 2) + + wasm_f32x4_extract_lane(v_sum, 3); + + *s = sumf; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __wasm_simd128__ + const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi = 0; + const int16_t * GGML_RESTRICT q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + // Load 64 4-bit weights (32 bytes) + const v128_t q4x0 = wasm_v128_load(q4); + const v128_t q4x1 = wasm_v128_load(q4 + 16); + q4 += 32; + + // Split into low/high nibbles + const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F)); + const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); + const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); + const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); + + // Load 64 8-bit values (64 bytes) + const v128_t q8x0 = wasm_v128_load(q8); + const v128_t q8x1 = wasm_v128_load(q8 + 16); + const v128_t q8x2 = wasm_v128_load(q8 + 32); + const v128_t q8x3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Low nibble products + v128_t vacc1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l0), + wasm_i16x8_extend_low_i8x16(q8x0) + ); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l0), + wasm_i16x8_extend_high_i8x16(q8x0) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4l1), + wasm_i16x8_extend_low_i8x16(q8x1) + )); + vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4l1), + wasm_i16x8_extend_high_i8x16(q8x1) + )); + + // High nibble products + v128_t vacc2 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h0), + wasm_i16x8_extend_low_i8x16(q8x2) + ); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h0), + wasm_i16x8_extend_high_i8x16(q8x2) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q4h1), + wasm_i16x8_extend_low_i8x16(q8x3) + )); + vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q4h1), + wasm_i16x8_extend_high_i8x16(q8x3) + )); + + // Accumulate scaled results + int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + + wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); + sumi1 += vacc1_sum * scales[2*j]; + + int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + + wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); + sumi2 += vacc2_sum * scales[2*j+1]; + } + + sumf += d * (sumi1 + sumi2); + } + + *s = sumf; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __wasm_simd128__ + //const uint8_t * scales = (const uint8_t*)&utmp[0]; + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign + + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // Process scales and mins + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + // Sum mins * q8sums + int32_t sumi_mins = 0; + const int16_t * GGML_RESTRICT q8sums = y[i].bsums; + const uint8_t * m = (const uint8_t *)&utmp[2]; + for (int j = 0; j < 16; j += 2) { + sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2]; + } + sumf -= dmin * sumi_mins; // Correct subtraction + + v128_t qh0 = wasm_v128_load(qh); + v128_t qh1 = wasm_v128_load(qh + 16); + const uint8_t * sc = (const uint8_t *)utmp; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + const int shift = j * 2; + v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift); + v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift); + + v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3); + v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4); + v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3); + + v128_t q5_0 = wasm_v128_load(q5); + v128_t q5_1 = wasm_v128_load(q5 + 16); + q5 += 32; + + v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0); + v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0); + v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1); + v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1); + + v128_t q8_0 = wasm_v128_load(q8); + v128_t q8_1 = wasm_v128_load(q8 + 16); + v128_t q8_2 = wasm_v128_load(q8 + 32); + v128_t q8_3 = wasm_v128_load(q8 + 48); + q8 += 64; + + // Process low quants + v128_t pl0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_0), + wasm_i16x8_extend_low_i8x16(q8_0) + ); + pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_0), + wasm_i16x8_extend_high_i8x16(q8_0) + )); + v128_t pl1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5l_1), + wasm_i16x8_extend_low_i8x16(q8_1) + ); + pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5l_1), + wasm_i16x8_extend_high_i8x16(q8_1) + )); + v128_t sum_low = wasm_i32x4_add(pl0, pl1); + + // Process high quants + v128_t ph0 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_0), + wasm_i16x8_extend_low_i8x16(q8_2) + ); + ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_0), + wasm_i16x8_extend_high_i8x16(q8_2) + )); + v128_t ph1 = wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_low_i8x16(q5h_1), + wasm_i16x8_extend_low_i8x16(q8_3) + ); + ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8( + wasm_i16x8_extend_high_i8x16(q5h_1), + wasm_i16x8_extend_high_i8x16(q8_3) + )); + v128_t sum_high = wasm_i32x4_add(ph0, ph1); + + // Accumulate with scale factors + int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) + + wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3); + int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) + + wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3); + + sumi += sl * sc[2*j] + sh * sc[2*j+1]; + } + + sumf += d * sumi; + } + + *s = sumf; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __wasm_simd128__ + int8_t aux8[QK_K] __attribute__((aligned(16))); + int32_t aux32[8] __attribute__((aligned(16))) = {0}; + float sums[8] __attribute__((aligned(16))) = {0}; + + for (int i = 0; i < nb; ++i) { + // Unpack 6-bit quantized data into aux8 (unchanged) + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + int8_t * a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + + const int8_t * GGML_RESTRICT a_ptr = aux8; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + v128_t acc0 = wasm_i32x4_splat(0); + v128_t acc1 = wasm_i32x4_splat(0); + + for (int j = 0; j < QK_K/16; ++j) { + const int scale = x[i].scales[j]; + const v128_t vscale = wasm_i32x4_splat(scale); + + // Load 16 elements from a and q8 + const v128_t a_vec = wasm_v128_load(a_ptr); + const v128_t q8_vec = wasm_v128_load(q8); + + // Process low 8 elements + v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec); + v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec); + v128_t prod_low = wasm_i16x8_mul(a_low, q8_low); + v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low); + v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low); + + // Process high 8 elements + v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec); + v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec); + v128_t prod_high = wasm_i16x8_mul(a_high, q8_high); + v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high); + v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high); + + // Scale and accumulate + prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale); + prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale); + prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale); + prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale); + + acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo)); + acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi)); + + a_ptr += 16; + q8 += 16; + } + + // Store accumulated results + wasm_v128_store(&aux32[0], acc0); + wasm_v128_store(&aux32[4], acc1); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) { + sums[l] += d * aux32[l]; + } + } + + // Sum final results + float sumf = 0; + for (int l = 0; l < 8; ++l) { + sumf += sums[l]; + } + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + diff --git a/ggml/src/ggml-cpu/cpu-feats-x86.cpp b/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp similarity index 99% rename from ggml/src/ggml-cpu/cpu-feats-x86.cpp rename to ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp index 902ee4346..d775a0363 100644 --- a/ggml/src/ggml-cpu/cpu-feats-x86.cpp +++ b/ggml/src/ggml-cpu/arch/x86/cpu-feats.cpp @@ -263,7 +263,7 @@ void test_x86_is() { static int ggml_backend_cpu_x86_score() { // FIXME: this does not check for OS support - int score = 0; + int score = 1; cpuid_x86 is; #ifdef GGML_FMA diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c new file mode 100644 index 000000000..e3f722b52 --- /dev/null +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -0,0 +1,4310 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" + +#include "../../quants.h" +#include "../../ggml-cpu-impl.h" + +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#elif defined(__AVXVNNI__) + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} + +static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { + const __m128i ax = _mm_sign_epi8(x, x); + const __m128i sy = _mm_sign_epi8(y, x); + return _mm_maddubs_epi16(ax, sy); +} + +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors +static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1, + const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) { + const __m128i mone = _mm_set1_epi16(1); + + const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1); + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); + const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1); + const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1); + return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1)); +} + +// quad fp16 delta calculation +static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) { + // GGML_FP16_TO_FP32 is faster than Intel F16C + return _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x1) * GGML_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_FP16_TO_FP32(x0) * GGML_FP16_TO_FP32(y0))); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + +void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_ref(x, y, k); +#endif +} + +void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * GGML_RESTRICT y = vy; +#if defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float max_scalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = max_scalar / 127.f; + y[i].d = GGML_FP32_TO_FP16(d); + const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Compute the sum of the quants and set y[i].s + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_ref(x, y, k); +#endif +} + +// placeholder implementation for Apple targets +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q8_K_ref(x, y, k); +} + +//===================================== Dot products ================================= + +// +// Helper functions +// + +#if __AVX__ || __AVX2__ || __AVX512F__ + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#endif + +void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + qx = _mm256_sub_epi8( qx, off ); + + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + __m256 accum = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8)); + const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8)); + const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8)); + const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8)); + + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1); + const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1); + const __m256 p = sum_i16_pairs_float(p_2, p_1); + + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (; ib + 1 < nb; ib += 2) { + _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (; ib < nb; ++ib) { + const float d0 = GGML_FP16_TO_FP32(x[ib].d); + const float d1 = GGML_FP16_TO_FP32(y[ib].d); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i qx = bytes_from_nibbles_32(x[ib].qs); + const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(qx, qy); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + sumf = hsum_float_8(acc) + summs; + +#endif + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + qx = _mm256_or_si256(qx, bxhi); + + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8((char)0xF0); + + // Main loop + for (; ib < nb; ++ib) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); + + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); + + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); + } + + sumf = hsum_float_8(acc); + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + __m256i qx = bytes_from_nibbles_32(x[ib].qs); + __m256i bxhi = bytes_from_bits_32(x[ib].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + qx = _mm256_or_si256(qx, bxhi); + + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_us8_pairs_float(qx, qy); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + sumf = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); + + float summs = 0.0f; + + // Main loop + for (; ib < nb; ++ib) { + const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); + + summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); + + __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); + const __m256i bxhi = bytes_from_bits_32(x[ib].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx_0); + __m128i bxh = _mm256_extractf128_si256(bx_0, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx_0 = MM256_SET_M128I(bxh, bxl); + + const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); + const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); + + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } + + sumf = hsum_float_8(acc) + summs; + +#endif + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + +#if defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (; ib < nb; ++ib) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( d, q, acc ); + } + + sumf = hsum_float_8(acc); +#elif defined(__AVX__) + __m256 accum = _mm256_setzero_ps(); + + for (; ib + 1 < nb; ib += 2) { + const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs); + const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1); + const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1); + const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); + const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1); + const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1); + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); + +#endif + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + + // first 32 bytes of 5 elements + { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); + // 8-bit multiplies with shifts, masks and adds + __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 + __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 + __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 + __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 + + // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? + + // Cancel the +1 from avg so that it behaves like a halving add + qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); + qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); + qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); + qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); + qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); + qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); + qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); + qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); + qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); + qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); + qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); + const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + qx4 = _mm256_maddubs_epi16(qx4, qy4); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + sumi2 = _mm256_add_epi16(sumi2, qx4); + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); + __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 + __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 + __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 + __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 + __m256i qx01 = MM256_SET_M128I(qx1, qx0); + __m256i qx23 = MM256_SET_M128I(qx3, qx2); + + // avx2 does not have 8-bit multiplies, so 16-bit it is. + qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); + qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); + __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); + + __m256i qx45 = MM256_SET_M128I(qx5, qx4); + + // Cancel the +1 from avg so that it behaves like a halving add + qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); + qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); + qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); + qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); + qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); + qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); + qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); + qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); + + const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); + const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); + const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); + + qx01 = _mm256_maddubs_epi16(qx01, qy01); + qx23 = _mm256_maddubs_epi16(qx23, qy23); + qx45 = _mm256_maddubs_epi16(qx45, qy45); + + sumi0 = _mm256_add_epi16(sumi0, qx01); + sumi1 = _mm256_add_epi16(sumi1, qx23); + sumi2 = _mm256_add_epi16(sumi2, qx45); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; + } + } + } + + for (size_t l = 0; l < 4; ++l) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; + } + } + + sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d); + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums, because 256*127 still fits + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); + __m256i qx1 = _mm256_srli_epi16(qx0, 2); + __m256i qx2 = _mm256_srli_epi16(qx0, 4); + __m256i qx3 = _mm256_srli_epi16(qx0, 6); + + // 0, 1, 2 (should not be 3) + qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_add_epi16(sumi0, sumi1); + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t k = 0; k < 32; ++k) { + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); + } + } + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + sumf += (float) sumi * d; + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + + // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#if defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * GGML_RESTRICT q5 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m15 = _mm_set1_epi8(15); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // handle the q6_k -32 offset separately using bsums + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1); + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales); + const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8)); + const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5); + const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2); + const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48)); + const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48)); + const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2); + const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2); + + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + + } + + sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0); + sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1); + const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#if defined (__AVX__) || defined (__AVX2__) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], + signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[1] >> 28; + const uint16_t ls2 = aux32[3] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__AVX2__) + + const __m256i mone = _mm256_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); + const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); + const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); + const __m256i m511 = _mm256_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; + aux_gindex = _mm256_and_si256(q2_data, m511); + + const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); + const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); + const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); + + const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); + const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); + + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + + const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], + iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); + const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], + iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); + const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], + iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); + const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], + iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + + const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); + const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); + const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); + const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); + + __m256i signs; + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); + + signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); + signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); + const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); + const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); + + const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); + const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); + const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); + const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); + sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const __m128i mone = _mm_set1_epi8(1); + static const char block_sign_shuffle_mask_1[32] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, + }; + static const char block_sign_shuffle_mask_2[32] = { + 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, + }; + static const uint8_t bit_selector_mask_bytes[32] = { + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); + const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); + const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); + const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); + const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); + const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); + + static const uint8_t k_bit_helper[32] = { + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, + }; + const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); + const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); + const __m128i m511 = _mm_set1_epi16(511); + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + uint64_t aux64; + + // somewhat hacky, but gives a significant boost in performance + __m256i aux_gindex; + const uint16_t * gindex = (const uint16_t *)&aux_gindex; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + __m128i stmp = _mm_set1_epi64x(aux64); + stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); + const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { + + const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; + aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); + + const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); + const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); + const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); + const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); + const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); + const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); + + const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); + const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); + const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); + const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); + + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); + const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); + const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); + const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); + const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); + + // AVX2 full_signs_1 is full_sign_bits_0 here + // AVX2 full_signs_2 is full_sign_bits_1 here + __m128i signs_0, signs_1; + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); + + signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); + signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); + signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); + signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); + const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); + const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); + const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); + const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); + const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); + + __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); + const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); + const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); + const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); + const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); + const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +#endif +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], + iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], + iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + qs += 8; + + __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 + + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i m4 = _mm_set1_epi8(0xf); + const __m128i m1 = _mm_set1_epi8(1); + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + uint64_t aux64; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(&aux64, x[i].scales, 8); + const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); + const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); + const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], + iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); + const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], + iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); + const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], + iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); + const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], + iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); + qs += 8; + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.125f * hsum_float_8(accumf); + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; + +#endif + +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__AVX2__) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], + iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], + signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], + signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); + const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + +#elif defined(__AVX__) + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[2]; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); + q3 += 8; + memcpy(aux32, gas, 8); gas += 8; + const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); + const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); + const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); + const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); + const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); + const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); + const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); + const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = aux32[0] >> 28; + const uint16_t ls2 = aux32[1] >> 28; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = 0.25f * hsum_float_8(accumf); + +#else + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +#endif +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined(__AVX2__) + + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); + const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); + + const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); + const __m256i idx_mask = _mm256_set1_epi32(256); + + typedef union { + __m256i vec[2]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; + idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); + idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); + idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); + idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); + idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); + idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); + + // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. + //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); + //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); + const __m256i q2_1 = _mm256_set_epi32( + iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], + iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] + ); + const __m256i q2_2 = _mm256_set_epi32( + iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], + iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] + ); + + __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); + + aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); + aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); + const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); + const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); + + signs += 4; + + const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); + const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); + sumi1 = _mm256_add_epi32(sumi1, p1); + sumi2 = _mm256_add_epi32(sumi2, p2); + } + + accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); + + } + + *s = hsum_float_8(accumf); + +#elif defined(__AVX__) + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 + }; + + static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, + }; + + const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); + const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); + const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); + const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); + + const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); + const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); + const __m128i idx_mask = _mm_set1_epi32(256); + + typedef union { + __m128i vec[4]; + uint32_t index[16]; + } index_t; + + index_t idx; + + __m256 accumf = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); + const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); + const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; + idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); + idx.vec[1] = idx.vec[0]; + idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); + idx.vec[3] = idx.vec[2]; + + idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); + idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); + idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); + idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); + + idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); + idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); + idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); + idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); + + const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); + const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); + const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); + const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); + + __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); + __m128i aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); + const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); + + aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); + aux128_1 = aux128_0; + aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); + aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); + const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); + const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); + const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); + const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); + + signs += 4; + + const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); + const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); + const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); + const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); + const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; + const uint16_t ls2 = x[i].scales[ib32/2] >> 4; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); + sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); + sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); + sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); + sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); + } + + accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); + + } + + *s = hsum_float_8(accumf); + +#else + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +#endif +} + +#if defined(__AVX2__) +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} +#endif + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __AVX2__ + + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m256i sumi = _mm256_setzero_si256(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { +#ifdef __BMI2__ + const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL); + const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL); + const uint16_t *idx1 = (const uint16_t *)(&packed_idx1); + const uint16_t *idx2 = (const uint16_t *)(&packed_idx2); + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]); +#else + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], + iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], + iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); +#endif + qs += 8; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); + const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#elif defined __AVX__ + __m256 accum = _mm256_setzero_ps(); + float accum1 = 0; + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + int sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); + qs += 8; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; + const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; + const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 + + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); + accum1 += d * sumi1; + + } + + *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; + +#else + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; + } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + +#if defined __AVX2__ + + const __m256i mask = _mm256_set1_epi16(0x7); + const __m256i mone = _mm256_set1_epi16(1); + const __m256i mone8 = _mm256_set1_epi8(1); + const __m256i mtwo8 = _mm256_set1_epi8(2); + // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half. + const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + // Extract 3-bit scales (16 values) + __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc); + scales = _mm256_srlv_epi64(scales, scales_shift); + scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone); + + // Indices to repeat each scale 8 times. + __m256i scales_idx1 = _mm256_set1_epi16(0x0100); + __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8)); + + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { +#ifdef __BMI2__ + const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) + | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL); + const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) + | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL); + const uint16_t *idx1 = (const uint16_t *)(&packed_idx1); + const uint16_t *idx2 = (const uint16_t *)(&packed_idx2); + const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]); + const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]); + + // Convert signs to bytes 0x81 (negative) or 0x01 (positive) + const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL); + const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign))); + const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32))); +#else + const __m256i q1b_1 = _mm256_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] + ); + const __m256i q1b_2 = _mm256_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] + ); + + const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); +#endif + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); + const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); + const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1)); + const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2)); + + __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1); + __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2); + + scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8); + scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8); + + const __m256i p1 = _mm256_madd_epi16(dot1, scale1); + const __m256i p2 = _mm256_madd_epi16(dot2, scale2); + const __m256i p3 = _mm256_madd_epi16(dot3, scale1); + const __m256i p4 = _mm256_madd_epi16(dot4, scale2); + + sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); + sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); + accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + +#elif defined __AVX__ + const __m128i mask = _mm_set1_epi16(0x7); + const __m128i mone = _mm_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q1b_1_0 = _mm_set_epi64x( + iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); + const __m128i q1b_1_1 = _mm_set_epi64x( + iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); + const __m128i q1b_2_0 = _mm_set_epi64x( + iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); + const __m128i q1b_2_1 = _mm_set_epi64x( + iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + + const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); + const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); + const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); + const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); + + const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, + qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); + + const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); + const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); + const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); + const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); + + __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); + __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); + __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); + __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); + + scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); + scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); + scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); + scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); + const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); + const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); + const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); + const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); + const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); + const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); + const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); + const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); + + sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); + sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); + sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); + sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); + + qs += 8; qh += 4; + } + + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); + + accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); + accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); + } + + *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); + +#else + + int sum1[2], sum2[2], delta[4]; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + delta[0] = qh[0] & 0x08 ? -1 : 1; + delta[1] = qh[0] & 0x80 ? -1 : 1; + delta[2] = qh[1] & 0x08 ? -1 : 1; + delta[3] = qh[1] & 0x80 ? -1 : 1; + sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); + int lsum1 = 0, lsum2 = 0; + for (int j = 0; j < 8; ++j) { + lsum1 += q8[j] * grid[j]; + lsum2 += q8[j]; + } + q8 += 8; + sum1[l/2] += lsum1; + sum2[l/2] += lsum2*delta[l]; + } + + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; + const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; + + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; + sumi2 += sum2[0] * ls1 + sum2[1] * ls2; + qs += 4; + qh += 2; + } + + sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; + +#endif +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + +#if defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + const __m256i mone = _mm256_set1_epi16(1); + + __m256 accum1 = _mm256_setzero_ps(); + __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); + const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); + accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), + _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), + _mm256_cvtepi32_ps(p_2), accum2); + } + + sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); + + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + + const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1); + const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); + accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); + } + + sumf = hsum_float_8(accum); + +#endif + for (; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + +#if defined __AVX2__ + + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; + const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; + const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); + const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), + _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); + const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); + const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); + const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); + sumi1 = _mm256_add_epi32(p_1, sumi1); + sumi2 = _mm256_add_epi32(p_2, sumi2); + } + accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); + } + + *s = hsum_float_8(accum); + +#elif defined __AVX__ + const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); + const __m128i m4b = _mm_set1_epi8(0x0f); + + __m256 accum = _mm256_setzero_ps(); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi1_0 = _mm_setzero_si128(); + __m128i sumi1_1 = _mm_setzero_si128(); + __m128i sumi2_0 = _mm_setzero_si128(); + __m128i sumi2_1 = _mm_setzero_si128(); + for (int ib = 0; ib < QK_K/32; ib += 2) { + const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; + const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; + const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); + const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); + const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); + const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); + const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); + const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); + const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); + const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); + const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; + sh >>= 4; + const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); + const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); + const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); + const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); + sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); + sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); + sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); + sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); + } + __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); + __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); + accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), + _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); + } + + *s = hsum_float_8(accum); + +#else + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +#endif +} + diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp similarity index 55% rename from ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp rename to ggml/src/ggml-cpu/arch/x86/repack.cpp index 74a31abb2..e7635a294 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -3,102 +3,26 @@ #include "ggml-common.h" #include "ggml-backend-impl.h" -#include "ggml-quants.h" #include "ggml-impl.h" #include "ggml-cpu.h" #include "ggml-cpu-impl.h" -#include "ggml-cpu-traits.h" +#include "traits.h" #include #include #include -#include #include // for qsort #include // for GGML_ASSERT -#include "ggml-cpu-aarch64.h" - -// TODO: move to include file? -template constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; - } - return -1; -} - -template struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - int8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks -}; - -// control size -static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); -static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); -static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); -static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); - -using block_q4_0x4 = block<4, 4>; -using block_q4_0x8 = block<4, 8>; -using block_q8_0x4 = block<8, 4>; -using block_q8_0x8 = block<8, 8>; - - -struct block_q4_Kx8 { - ggml_half d[8]; // super-block scale for quantized scales - ggml_half dmin[8]; // super-block scale for quantized mins - uint8_t scales[96]; // scales and mins, quantized with 6 bits - uint8_t qs[1024]; // 4--bit quants -}; - -static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); - -struct block_q8_Kx4 { - float d[4]; // delta - int8_t qs[QK_K * 4]; // quants - int16_t bsums[QK_K / 4]; // sum of quants in groups of 16 -}; - -static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding"); - -struct block_iq4_nlx4 { - ggml_half d[4]; // deltas for 4 iq4_nl blocks - uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks -}; - -static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); +#define GGML_CPU_CLANG_WORKAROUND +#include "../../repack.h" #if defined(__GNUC__) #pragma GCC diagnostic ignored "-Woverlength-strings" -#elif defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data #endif #define UNUSED GGML_UNUSED -static inline int nearest_int(float fval) { - assert(fabsf(fval) <= 4194303.f); - float val = fval + 12582912.f; - int i; memcpy(&i, &val, sizeof(int)); - return (i & 0x007fffff) - 0x00400000; -} - -// Functions to create the interleaved data layout formats - -// interleave 4 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x4 -// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks -// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave -// -// - in : an array of block_q4_0 pointers -// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of -// blck_size_interleave bytes -// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes -// from bias offset form to pure sign form (this saves subtract -// operations durin unpacking) -// #if defined(__AVX__) #if defined(__F16C__) #if defined(__AVX512F__) @@ -180,256 +104,84 @@ static inline __m256 __avx_rearranged_f32cx8_load(ggml_fp16_t *x, __m128i arrang #endif #endif +static inline int nearest_int(float fval) { + assert(fabsf(fval) <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} #if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX512F__) -// add int16_t pairwise and return as 512 bit int vector -static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) { +// add int16_t pairwise and return as 512 bit int vector, then add the accumulator +static inline __m512i sum_i16_pairs_acc_int32x16(const __m512i acc, const __m512i x) { const __m512i ones = _mm512_set1_epi16(1); - return _mm512_madd_epi16(ones, x); + return _mm512_add_epi32(acc, _mm512_madd_epi16(ones, x)); } -static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) { +static inline __m512i mul_sum_us8_pairs_acc_int32x16(const __m512i acc, const __m512i ax, const __m512i sy) { #if defined(__AVX512VNNI__) - const __m512i zero = _mm512_setzero_si512(); - return _mm512_dpbusd_epi32(zero, ax, sy); + return _mm512_dpbusd_epi32(acc, ax, sy); #else // Perform multiplication and create 16-bit values const __m512i dot = _mm512_maddubs_epi16(ax, sy); - return sum_i16_pairs_int_32x16(dot); + return sum_i16_pairs_acc_int32x16(acc, dot); #endif } -// multiply int8_t, add results pairwise twice and return as 512 bit int vector -static inline __m512i mul_sum_i8_pairs_int32x16(const __m512i x, const __m512i y) { +// multiply int8_t, add results pairwise twice and return as 512 bit int vector,then add the accumulator +static inline __m512i mul_sum_i8_pairs_acc_int32x16(const __m512i acc, const __m512i x, const __m512i y) { const __m512i zero = _mm512_setzero_si512(); // Get absolute values of x vectors const __m512i ax = _mm512_abs_epi8(x); // Sign the values of the y vectors __mmask64 blt0 = _mm512_movepi8_mask(x); const __m512i sy = _mm512_mask_sub_epi8(y, blt0, zero, y); - return mul_sum_us8_pairs_int32x16(ax, sy); + return mul_sum_us8_pairs_acc_int32x16(acc, ax, sy); } #endif -// add int16_t pairwise and return as 256 bit int vector -static inline __m256i sum_i16_pairs_int32x8(const __m256i x) { +// add int16_t pairwise and return as 256 bit int vector, then add the accumulator +static inline __m256i sum_i16_pairs_acc_int32x8(const __m256i acc, const __m256i x) { const __m256i ones = _mm256_set1_epi16(1); - return _mm256_madd_epi16(ones, x); + return _mm256_add_epi32(acc, _mm256_madd_epi16(ones, x)); } -static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) { +static inline __m256i mul_sum_us8_pairs_acc_int32x8(const __m256i acc, const __m256i ax, const __m256i sy) { #if defined(__AVX512VNNI__) && defined(__AVX512VL__) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_dpbusd_epi32(acc, ax, sy); #elif defined(__AVXVNNI__) - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbusd_avx_epi32(zero, ax, sy); + return _mm256_dpbusd_avx_epi32(acc, ax, sy); #else // Perform multiplication and create 16-bit values const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_int32x8(dot); + return sum_i16_pairs_acc_int32x8(acc, dot); #endif } // Integer variant of the function defined in ggml-quants.c -// multiply int8_t, add results pairwise twice and return as 256 bit int vector -static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - return _mm256_dpbssd_epi32(zero, x, y); +// multiply int8_t, add results pairwise twice and return as 256 bit int vector, then add the accumulator +static inline __m256i mul_sum_i8_pairs_acc_int32x8(const __m256i acc, const __m256i x, const __m256i y) { +#if defined(__AVXVNNIINT8__) + return _mm256_dpbssd_epi32(acc, x, y); #else // Get absolute values of x vectors const __m256i ax = _mm256_sign_epi8(x, x); // Sign the values of the y vectors const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_int32x8(ax, sy); + return mul_sum_us8_pairs_acc_int32x8(acc, ax, sy); #endif } #endif -static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - -static void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); const int nb = k / QK8_0; block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 8; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 3] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 11] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[16 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[16 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[16 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[16 * j + 15] = vgetq_lane_s32(vi, 3); - } - } -#else - // scalar - const int blck_size_interleave = 4; - float srcv[4][QK8_0]; - float id[4]; - - for (int i = 0; i < nb; i++) { - for (int row_iter = 0; row_iter < 4; row_iter++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; - amax = MAX(amax, fabsf(srcv[row_iter][j])); - } - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < QK8_0 * 4; j++) { - int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; - int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; - src_offset += (j % blck_size_interleave); - - float x0 = srcv[src_id][src_offset] * id[src_id]; - y[i].qs[j] = roundf(x0); - } - } -#endif -} - -static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; - -#if defined(__ARM_NEON) - float32x4_t srcv[4][8]; - float id[4]; - - for (int i = 0; i < nb; i++) { - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int row_iter = 0; row_iter < 4; row_iter++) { - for (int j = 0; j < 8; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 32 + 4 * j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]); - - for (int j = 0; j < 4; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]); - for (int j = 0; j < 2; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]); - for (int j = 0; j < 1; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - id[row_iter] = d ? 1.0f / d : 0.0f; - - y[i].d[row_iter] = GGML_FP32_TO_FP16(d); - } - - for (int j = 0; j < 4; j++) { - float32x4_t v = vmulq_n_f32(srcv[0][2 * j], id[0]); - int32x4_t vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 3] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[0][2 * j + 1], id[0]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 4] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 5] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 6] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 7] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[1][2 * j], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 8] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 9] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 10] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 11] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[1][2 * j + 1], id[1]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 12] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 13] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 14] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 15] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[2][2 * j], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 16] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 17] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 18] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 19] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[2][2 * j + 1], id[2]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 20] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 21] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 22] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 23] = vgetq_lane_s32(vi, 3); - - v = vmulq_n_f32(srcv[3][2 * j], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 24] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 25] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 26] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 27] = vgetq_lane_s32(vi, 3); - v = vmulq_n_f32(srcv[3][2 * j + 1], id[3]); - vi = vcvtnq_s32_f32(v); - y[i].qs[32 * j + 28] = vgetq_lane_s32(vi, 0); - y[i].qs[32 * j + 29] = vgetq_lane_s32(vi, 1); - y[i].qs[32 * j + 30] = vgetq_lane_s32(vi, 2); - y[i].qs[32 * j + 31] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) +#if defined(__AVX2__) || defined(__AVX__) float id[4]; __m256 srcv[4][4]; __m256 idvec[4]; @@ -526,6 +278,7 @@ static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGM #endif } } + #else // scalar const int blck_size_interleave = 8; @@ -559,7 +312,7 @@ static void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGM #endif } -static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { +void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); const int nb = k / QK_K; @@ -823,203 +576,7 @@ static void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGM #endif } -template -void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row); - -template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { - assert(nrow == 4); - UNUSED(nrow); - ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); -} - -template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { - assert(nrow == 4); - UNUSED(nrow); - ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); -} - -template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { - assert(nrow == 4); - UNUSED(nrow); - ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); -} - -static void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; - - for (int c = 0; c < nc; c += ncols_interleaved) { - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - float32x4_t acc = vdupq_n_f32(0); - for (int b = 0; b < nb; b++) { - int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); - int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); - int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); - int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); - float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); - - int8x16_t a0 = vld1q_s8(a_ptr->qs); - int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2); - float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); - - int32x4_t ret = vdupq_n_s32(0); - - ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0); - ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1); - ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2); - ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3); - - ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0); - ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1); - ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2); - ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3); - - acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), - vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); - a_ptr++; - b_ptr++; - } - vst1q_f32(s, acc); - s += ncols_interleaved; - } - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - -static void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; - - for (int c = 0; c < nc; c += ncols_interleaved) { - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - float32x4_t acc = vdupq_n_f32(0); - for (int b = 0; b < nb; b++) { - int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); - int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); - int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); - int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); - float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); - - int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs); - int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1); - int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2); - int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3); - float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); - - int32x4_t ret0 = vdupq_n_s32(0); - int32x4_t ret1 = vdupq_n_s32(0); - - ret0 = vdotq_s32(ret0, b0 << 4, a0); - ret1 = vdotq_s32(ret1, b1 << 4, a0); - ret0 = vdotq_s32(ret0, b2 << 4, a1); - ret1 = vdotq_s32(ret1, b3 << 4, a1); - - ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2); - ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2); - ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3); - ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3); - - int32x4_t ret = vpaddq_s32(ret0, ret1); - - acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), - vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); - a_ptr++; - b_ptr++; - } - vst1q_f32(s, acc); - s += ncols_interleaved; - } - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} - -static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 8; @@ -1038,75 +595,7 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c UNUSED(ncols_interleaved); UNUSED(blocklen); -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) -#if defined(__ARM_FEATURE_SVE) - if (ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "ptrue p0.b\n" - "add %x[b_ptr], %x[b_ptr], #0x10\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "mov z31.b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ld1b { z30.b }, p0/Z, [%x[b_ptr]]\n" - "ld1b { z29.b }, p0/Z, [%x[b_ptr], #1, MUL VL]\n" - "mov z28.s, #0x0\n" - "mov z27.s, #0x0\n" - "ld1rd { z26.d }, p0/Z, [x22]\n" - "ld1b { z25.b }, p0/Z, [%x[b_ptr], #2, MUL VL]\n" - "sub x20, x22, #0x2\n" - "sub x21, x21, #0x1\n" - "ld1b { z24.b }, p0/Z, [%x[b_ptr], #3, MUL VL]\n" - "ld1rd { z23.d }, p0/Z, [x22, #8]\n" - "lsl z22.b, z30.b, #0x4\n" - "lsl z16.b, z29.b, #0x4\n" - "and z30.b, z30.b, #0xf0\n" - "and z29.b, z29.b, #0xf0\n" - "ld1rd { z21.d }, p0/Z, [x22, #16]\n" - "ld1rd { z20.d }, p0/Z, [x22, #24]\n" - "lsl z19.b, z25.b, #0x4\n" - "and z25.b, z25.b, #0xf0\n" - "ld1rh { z17.h }, p0/Z, [x20]\n" - "ld1h { z18.s }, p0/Z, [%x[b_ptr], #-1, MUL VL]\n" - "sdot z28.s, z22.b, z26.b\n" - "sdot z27.s, z16.b, z26.b\n" - "lsl z16.b, z24.b, #0x4\n" - "add x22, x22, #0x22\n" - "and z24.b, z24.b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x90\n" - "fcvt z17.s, p0/m, z17.h\n" - "fcvt z18.s, p0/m, z18.h\n" - "sdot z28.s, z19.b, z23.b\n" - "sdot z27.s, z16.b, z23.b\n" - "fmul z18.s, z18.s, z17.s\n" - "sdot z28.s, z30.b, z21.b\n" - "sdot z27.s, z29.b, z21.b\n" - "sdot z28.s, z25.b, z20.b\n" - "sdot z27.s, z24.b, z20.b\n" - "uzp1 z17.s, z28.s, z27.s\n" - "uzp2 z16.s, z28.s, z27.s\n" - "add z17.s, z17.s, z16.s\n" - "asr z17.s, z17.s, #0x4\n" - "scvtf z17.s, p0/m, z17.s\n" - "fmla z31.s, p0/M, z17.s, z18.s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x8\n" - "st1w { z31.s }, p0, [%x[res_ptr]]\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "p0", "x20", "x21", "x22", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } -#endif // #if defined(__ARM_FEATURE_SVE) -#elif defined(__AVX2__) +#if defined(__AVX2__) // Lookup table to convert signed nibbles to signed bytes __m256i signextendlut = _mm256_castsi128_si256(_mm_set_epi8(-1, -2, -3, -4, -5, -6, -7, -8, 7, 6, 5, 4, 3, 2, 1, 0)); signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); @@ -1175,17 +664,17 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // ........................................................................... // B0(28-31) B4(28-31) B1(28-31) B5(28-31) B2(28-31) B6(28-31) B3(28-31) B7(28-31) with A0(28-31) - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_0 ,_mm256_shuffle_epi32(rhs_vec_4567_0, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 0)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_0, 177) ,rhs_vec_4567_0, 170), _mm256_shuffle_epi32(lhs_vec_0, 85)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_1 ,_mm256_shuffle_epi32(rhs_vec_4567_1, 177), 170), _mm256_shuffle_epi32(lhs_vec_0, 170)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_1, 177) ,rhs_vec_4567_1, 170), _mm256_shuffle_epi32(lhs_vec_0, 255)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_2 ,_mm256_shuffle_epi32(rhs_vec_4567_2, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 0)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_2, 177) ,rhs_vec_4567_2, 170), _mm256_shuffle_epi32(lhs_vec_1, 85)); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170))); - iacc = _mm256_add_epi32(iacc, mul_sum_i8_pairs_int32x8(_mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255))); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)); + iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)); // Accumulated values multipled with appropriate scales acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); @@ -1197,74 +686,8 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c } } return; -#elif defined(__riscv_v_intrinsic) - if (__riscv_vlenb() >= QK4_0) { - const size_t vl = QK4_0; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - - vfloat32m1_t sumf = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - for (int l = 0; l < nb; l++) { - const int64_t a0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t a1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t a2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t a3 = *(const int64_t *)&a_ptr[l].qs[24]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a1, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a2, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(a3, vl / 4)); - - const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); - const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); - const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); - const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); - const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); - const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); - const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); - - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_hi_m)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - // vector version needs Zvfhmin extension - const float a_scale = GGML_FP16_TO_FP32(a_ptr[l].d); - const float b_scales[8] = { - GGML_FP16_TO_FP32(b_ptr[l].d[0]), - GGML_FP16_TO_FP32(b_ptr[l].d[1]), - GGML_FP16_TO_FP32(b_ptr[l].d[2]), - GGML_FP16_TO_FP32(b_ptr[l].d[3]), - GGML_FP16_TO_FP32(b_ptr[l].d[4]), - GGML_FP16_TO_FP32(b_ptr[l].d[5]), - GGML_FP16_TO_FP32(b_ptr[l].d[6]), - GGML_FP16_TO_FP32(b_ptr[l].d[7]) - }; - const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scale, vl / 4); - sumf = __riscv_vfmacc_vv_f32m1(sumf, tmp1, b_scales_vec, vl / 4); - } - __riscv_vse32_v_f32m1(s + x * ncols_interleaved, sumf, vl / 4); - } - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) +#endif { float sumf[8]; int sumi; @@ -1292,7 +715,7 @@ static void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c } } -static void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int ncols_interleaved = 8; @@ -1566,1074 +989,7 @@ static void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c #endif } - -static void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - float * res_ptr = s; - - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - float32x4_t sumf = vdupq_n_f32(0); - for (int l = 0; l < nb; l++) { - uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0); - uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16); - uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32); - uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48); - - int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4); - int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F); - int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4); - int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F); - int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4); - int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F); - int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4); - int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F); - - int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0); - int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16); - - int32x4_t sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0); - sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0); - sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1); - sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1); - sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2); - sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2); - sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3); - sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3); - - float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d)); - float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); - float32x4_t d = a_d * b_d; - - sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi)); - } - - vst1q_f32(res_ptr + x * 4, sumf); - } - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - { - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); - } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } - } -} - -static void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v23.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v0.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "movi v1.16b, #0x0\n" - "3:" // Block loop - "ldr q3, [x28, #0x0]\n" - "ldr q31, [x25, #0x0]\n" - "movi v28.16b, #0x4\n" - "movi v10.4s, #0x0\n" - "ldr q22, [x28, #0x10]\n" - "ldr q6, [x25, #0x10]\n" - "movi v29.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q27, [x28, #0x20]\n" - "ldr q30, [x28, #0x30]\n" - "movi v20.4s, #0x0\n" - "movi v24.16b, #0xf0\n" - "ldr d2, [x25, #-0x8]\n" - "ldr d26, [x23, #-0x8]\n" - "sshl v12.16b, v3.16b, v28.16b\n" - "sub x20, x28, #0x8\n" - "ldr d17, [x20, #0x0]\n" - "and v3.16b, v3.16b, v24.16b\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" - ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" - ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" - ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" - "sshl v31.16b, v22.16b, v28.16b\n" - "and v22.16b, v22.16b, v24.16b\n" - "fcvtl v17.4s, v17.4h\n" - "fcvtl v2.4s, v2.4h\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" - ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" - ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" - ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" - "sshl v6.16b, v27.16b, v28.16b\n" - "sshl v28.16b, v30.16b, v28.16b\n" - "and v27.16b, v27.16b, v24.16b\n" - "and v30.16b, v30.16b, v24.16b\n" - "ldr q24, [x25, #0x20]\n" - ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x30]\n" - ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" - ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" - ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" - ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x40]\n" - ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x50]\n" - ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" - ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" - ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" - ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x60]\n" - ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" - ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" - ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" - ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" - "fmul v24.4s, v17.4s, v2.s[0]\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v15.4s, v10.4s, v24.4s\n" - "ldr q24, [x23, #0x0]\n" - "fmul v10.4s, v17.4s, v2.s[1]\n" - "fmla v19.4s, v29.4s, v10.4s\n" - "ldr q10, [x23, #0x10]\n" - "fmul v29.4s, v17.4s, v2.s[2]\n" - "fmul v2.4s, v17.4s, v2.s[3]\n" - "fmla v18.4s, v9.4s, v29.4s\n" - "movi v9.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" - "fmla v14.4s, v20.4s, v2.4s\n" - "movi v20.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x20]\n" - ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" - ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" - ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" - ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x30]\n" - ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x40]\n" - ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" - ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" - ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" - ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x50]\n" - ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x60]\n" - ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" - ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" - ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" - ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x0]\n" - ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" - ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" - ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" - ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" - "fmul v10.4s, v17.4s, v26.s[0]\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v11.4s, v9.4s, v10.4s\n" - "ldr q9, [x22, #0x10]\n" - "fmul v10.4s, v17.4s, v26.s[1]\n" - "fmla v13.4s, v29.4s, v10.4s\n" - "ldr d29, [x22, #-0x8]\n" - "fmul v10.4s, v17.4s, v26.s[2]\n" - "fmul v26.4s, v17.4s, v26.s[3]\n" - "fcvtl v29.4s, v29.4h\n" - "fmla v23.4s, v20.4s, v10.4s\n" - "movi v20.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v16.4s, v2.4s, v26.4s\n" - "movi v26.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x20]\n" - ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x30]\n" - ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x40]\n" - ".inst 0x4f89e394 // sdot v20.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb9a // sdot v26.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb82 // sdot v2.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x50]\n" - ".inst 0x4f98e074 // sdot v20.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e06a // sdot v10.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e87a // sdot v26.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x60]\n" - ".inst 0x4f89e2d4 // sdot v20.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eada // sdot v26.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eac2 // sdot v2.4s, v22.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4f98e374 // sdot v20.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e36a // sdot v10.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb7a // sdot v26.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x21, #0x0]\n" - ".inst 0x4f89e3d4 // sdot v20.4s, v30.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ca // sdot v10.4s, v30.16b, v9.4b[1]\n" - ".inst 0x4f89ebda // sdot v26.4s, v30.16b, v9.4b[2]\n" - ".inst 0x4fa9ebc2 // sdot v2.4s, v30.16b, v9.4b[3]\n" - "fmul v9.4s, v17.4s, v29.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v25.4s, v20.4s, v9.4s\n" - "ldr q9, [x21, #0x10]\n" - "fmul v20.4s, v17.4s, v29.s[1]\n" - "fmla v7.4s, v10.4s, v20.4s\n" - "ldr d20, [x21, #-0x8]\n" - "fmul v10.4s, v17.4s, v29.s[2]\n" - "fmul v29.4s, v17.4s, v29.s[3]\n" - "fcvtl v20.4s, v20.4h\n" - "fmla v0.4s, v26.4s, v10.4s\n" - "movi v26.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v4.4s, v2.4s, v29.4s\n" - "movi v2.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e19a // sdot v26.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e982 // sdot v2.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e99d // sdot v29.4s, v12.16b, v24.4b[3]\n" - "ldr q12, [x21, #0x20]\n" - "fmul v24.4s, v17.4s, v20.s[0]\n" - ".inst 0x4f89e3fa // sdot v26.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebe2 // sdot v2.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebfd // sdot v29.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x30]\n" - "fmul v31.4s, v17.4s, v20.s[1]\n" - ".inst 0x4f8ce0da // sdot v26.4s, v6.16b, v12.4b[0]\n" - ".inst 0x4face0ca // sdot v10.4s, v6.16b, v12.4b[1]\n" - ".inst 0x4f8ce8c2 // sdot v2.4s, v6.16b, v12.4b[2]\n" - ".inst 0x4face8dd // sdot v29.4s, v6.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x40]\n" - "fmul v6.4s, v17.4s, v20.s[2]\n" - "fmul v20.4s, v17.4s, v20.s[3]\n" - ".inst 0x4f89e39a // sdot v26.4s, v28.16b, v9.4b[0]\n" - ".inst 0x4fa9e38a // sdot v10.4s, v28.16b, v9.4b[1]\n" - ".inst 0x4f89eb82 // sdot v2.4s, v28.16b, v9.4b[2]\n" - ".inst 0x4fa9eb9d // sdot v29.4s, v28.16b, v9.4b[3]\n" - "ldr q9, [x21, #0x50]\n" - ".inst 0x4f8ce07a // sdot v26.4s, v3.16b, v12.4b[0]\n" - ".inst 0x4face06a // sdot v10.4s, v3.16b, v12.4b[1]\n" - ".inst 0x4f8ce862 // sdot v2.4s, v3.16b, v12.4b[2]\n" - ".inst 0x4face87d // sdot v29.4s, v3.16b, v12.4b[3]\n" - "ldr q12, [x21, #0x60]\n" - ".inst 0x4f89e2da // sdot v26.4s, v22.16b, v9.4b[0]\n" - ".inst 0x4fa9e2ca // sdot v10.4s, v22.16b, v9.4b[1]\n" - ".inst 0x4f89eac2 // sdot v2.4s, v22.16b, v9.4b[2]\n" - ".inst 0x4fa9eadd // sdot v29.4s, v22.16b, v9.4b[3]\n" - "ldr q17, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4f8ce37a // sdot v26.4s, v27.16b, v12.4b[0]\n" - ".inst 0x4face36a // sdot v10.4s, v27.16b, v12.4b[1]\n" - ".inst 0x4f8ceb62 // sdot v2.4s, v27.16b, v12.4b[2]\n" - ".inst 0x4faceb7d // sdot v29.4s, v27.16b, v12.4b[3]\n" - ".inst 0x4f91e3da // sdot v26.4s, v30.16b, v17.4b[0]\n" - ".inst 0x4fb1e3ca // sdot v10.4s, v30.16b, v17.4b[1]\n" - ".inst 0x4f91ebc2 // sdot v2.4s, v30.16b, v17.4b[2]\n" - ".inst 0x4fb1ebdd // sdot v29.4s, v30.16b, v17.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "fmla v5.4s, v26.4s, v24.4s\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v21.4s, v10.4s, v31.4s\n" - "fmla v8.4s, v2.4s, v6.4s\n" - "fmla v1.4s, v29.4s, v20.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q16, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q0, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q21, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q8, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q1, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q7, [x24, #0x0]\n" - "ldr q5, [x25, #0x0]\n" - "movi v9.16b, #0x4\n" - "movi v4.4s, #0x0\n" - "ldr q3, [x24, #0x10]\n" - "ldr q2, [x25, #0x10]\n" - "movi v1.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q13, [x24, #0x20]\n" - "ldr q31, [x25, #0x20]\n" - "movi v30.4s, #0x0\n" - "movi v29.16b, #0xf0\n" - "ldr q28, [x24, #0x30]\n" - "ldr q27, [x25, #0x30]\n" - "sshl v20.16b, v7.16b, v9.16b\n" - "sub x20, x24, #0x8\n" - "ldr q26, [x25, #0x40]\n" - "ldr q25, [x25, #0x50]\n" - "sshl v17.16b, v3.16b, v9.16b\n" - "and v7.16b, v7.16b, v29.16b\n" - "ldr q24, [x25, #0x60]\n" - "ldr q16, [x25, #0x70]\n" - "sshl v22.16b, v13.16b, v9.16b\n" - "and v3.16b, v3.16b, v29.16b\n" - "ldr d21, [x20, #0x0]\n" - "ldr d12, [x25, #-0x8]\n" - ".inst 0x4f85e284 // sdot v4.4s, v20.16b, v5.4b[0]\n" - ".inst 0x4fa5e281 // sdot v1.4s, v20.16b, v5.4b[1]\n" - ".inst 0x4f85ea80 // sdot v0.4s, v20.16b, v5.4b[2]\n" - ".inst 0x4fa5ea9e // sdot v30.4s, v20.16b, v5.4b[3]\n" - "sshl v9.16b, v28.16b, v9.16b\n" - "subs x21, x21, #0x1\n" - "and v13.16b, v13.16b, v29.16b\n" - "and v28.16b, v28.16b, v29.16b\n" - "add x25, x25, #0x88\n" - "add x24, x24, #0x48\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v12.4s, v12.4h\n" - ".inst 0x4f82e224 // sdot v4.4s, v17.16b, v2.4b[0]\n" - ".inst 0x4fa2e221 // sdot v1.4s, v17.16b, v2.4b[1]\n" - ".inst 0x4f82ea20 // sdot v0.4s, v17.16b, v2.4b[2]\n" - ".inst 0x4fa2ea3e // sdot v30.4s, v17.16b, v2.4b[3]\n" - "fmul v11.4s, v21.4s, v12.s[0]\n" - "fmul v23.4s, v21.4s, v12.s[1]\n" - "fmul v17.4s, v21.4s, v12.s[2]\n" - ".inst 0x4f9fe2c4 // sdot v4.4s, v22.16b, v31.4b[0]\n" - "fmul v6.4s, v21.4s, v12.s[3]\n" - ".inst 0x4fbfe2c1 // sdot v1.4s, v22.16b, v31.4b[1]\n" - ".inst 0x4f9feac0 // sdot v0.4s, v22.16b, v31.4b[2]\n" - ".inst 0x4fbfeade // sdot v30.4s, v22.16b, v31.4b[3]\n" - ".inst 0x4f9be124 // sdot v4.4s, v9.16b, v27.4b[0]\n" - ".inst 0x4fbbe121 // sdot v1.4s, v9.16b, v27.4b[1]\n" - ".inst 0x4f9be920 // sdot v0.4s, v9.16b, v27.4b[2]\n" - ".inst 0x4fbbe93e // sdot v30.4s, v9.16b, v27.4b[3]\n" - ".inst 0x4f9ae0e4 // sdot v4.4s, v7.16b, v26.4b[0]\n" - ".inst 0x4fbae0e1 // sdot v1.4s, v7.16b, v26.4b[1]\n" - ".inst 0x4f9ae8e0 // sdot v0.4s, v7.16b, v26.4b[2]\n" - ".inst 0x4fbae8fe // sdot v30.4s, v7.16b, v26.4b[3]\n" - ".inst 0x4f99e064 // sdot v4.4s, v3.16b, v25.4b[0]\n" - ".inst 0x4fb9e061 // sdot v1.4s, v3.16b, v25.4b[1]\n" - ".inst 0x4f99e860 // sdot v0.4s, v3.16b, v25.4b[2]\n" - ".inst 0x4fb9e87e // sdot v30.4s, v3.16b, v25.4b[3]\n" - ".inst 0x4f98e1a4 // sdot v4.4s, v13.16b, v24.4b[0]\n" - ".inst 0x4fb8e1a1 // sdot v1.4s, v13.16b, v24.4b[1]\n" - ".inst 0x4f98e9a0 // sdot v0.4s, v13.16b, v24.4b[2]\n" - ".inst 0x4fb8e9be // sdot v30.4s, v13.16b, v24.4b[3]\n" - ".inst 0x4f90e384 // sdot v4.4s, v28.16b, v16.4b[0]\n" - ".inst 0x4fb0e381 // sdot v1.4s, v28.16b, v16.4b[1]\n" - ".inst 0x4f90eb80 // sdot v0.4s, v28.16b, v16.4b[2]\n" - ".inst 0x4fb0eb9e // sdot v30.4s, v28.16b, v16.4b[3]\n" - "scvtf v4.4s, v4.4s, #0x4\n" - "scvtf v1.4s, v1.4s, #0x4\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "fmla v15.4s, v4.4s, v11.4s\n" - "scvtf v30.4s, v30.4s, #0x4\n" - "fmla v19.4s, v1.4s, v23.4s\n" - "fmla v18.4s, v0.4s, v17.4s\n" - "fmla v14.4s, v30.4s, v6.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q15, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q19, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q18, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q14, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - { - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - -static void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v22.16b, #0x0\n" - "movi v23.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v6.16b, #0x0\n" - "movi v30.16b, #0x0\n" - "movi v24.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "3:" // Block loop - "ldr q21, [x28, #0x0]\n" - "ldr q16, [x28, #0x10]\n" - "movi v1.16b, #0x4\n" - "movi v19.4s, #0x0\n" - "ldr q27, [x25, #0x0]\n" - "ldr q15, [x25, #0x10]\n" - "movi v26.4s, #0x0\n" - "movi v18.4s, #0x0\n" - "ldr q29, [x28, #0x20]\n" - "ldr q3, [x28, #0x30]\n" - "movi v17.4s, #0x0\n" - "movi v0.16b, #0xf0\n" - "ldr d20, [x25, #-0x8]\n" - "ldr d9, [x23, #-0x8]\n" - "sshl v8.16b, v21.16b, v1.16b\n" - "sshl v31.16b, v16.16b, v1.16b\n" - "and v21.16b, v21.16b, v0.16b\n" - "and v16.16b, v16.16b, v0.16b\n" - "sub x20, x28, #0x8\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4e88a773 // smmla v19.4s, v27.16b, v8.16b\n" - ".inst 0x4e9fa77a // smmla v26.4s, v27.16b, v31.16b\n" - "ldr q27, [x25, #0x20]\n" - ".inst 0x4e88a5f2 // smmla v18.4s, v15.16b, v8.16b\n" - ".inst 0x4e9fa5f1 // smmla v17.4s, v15.16b, v31.16b\n" - "sshl v15.16b, v29.16b, v1.16b\n" - "sshl v1.16b, v3.16b, v1.16b\n" - "and v29.16b, v29.16b, v0.16b\n" - "and v3.16b, v3.16b, v0.16b\n" - "ldr q0, [x25, #0x30]\n" - "fcvtl v20.4s, v20.4h\n" - ".inst 0x4e8fa773 // smmla v19.4s, v27.16b, v15.16b\n" - "fcvtl v9.4s, v9.4h\n" - ".inst 0x4e81a77a // smmla v26.4s, v27.16b, v1.16b\n" - "ldr q27, [x25, #0x40]\n" - ".inst 0x4e8fa412 // smmla v18.4s, v0.16b, v15.16b\n" - ".inst 0x4e81a411 // smmla v17.4s, v0.16b, v1.16b\n" - "ldr q0, [x25, #0x50]\n" - ".inst 0x4e95a773 // smmla v19.4s, v27.16b, v21.16b\n" - ".inst 0x4e90a77a // smmla v26.4s, v27.16b, v16.16b\n" - "ldr q27, [x25, #0x60]\n" - ".inst 0x4e95a412 // smmla v18.4s, v0.16b, v21.16b\n" - ".inst 0x4e90a411 // smmla v17.4s, v0.16b, v16.16b\n" - "ldr q0, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4e9da773 // smmla v19.4s, v27.16b, v29.16b\n" - ".inst 0x4e83a77a // smmla v26.4s, v27.16b, v3.16b\n" - "ldr d27, [x20, #0x0]\n" - ".inst 0x4e9da412 // smmla v18.4s, v0.16b, v29.16b\n" - ".inst 0x4e83a411 // smmla v17.4s, v0.16b, v3.16b\n" - "fcvtl v27.4s, v27.4h\n" - "uzp1 v0.2d, v19.2d, v26.2d\n" - "uzp2 v26.2d, v19.2d, v26.2d\n" - "fmul v19.4s, v27.4s, v20.s[0]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v2.4s, v0.4s, v19.4s\n" - "ldr q19, [x23, #0x0]\n" - "uzp1 v0.2d, v18.2d, v17.2d\n" - "uzp2 v18.2d, v18.2d, v17.2d\n" - "fmul v17.4s, v27.4s, v20.s[1]\n" - "scvtf v0.4s, v0.4s, #0x4\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v10.4s, v26.4s, v17.4s\n" - "ldr q17, [x23, #0x10]\n" - "fmul v26.4s, v27.4s, v20.s[2]\n" - "fmul v20.4s, v27.4s, v20.s[3]\n" - "fmla v12.4s, v0.4s, v26.4s\n" - "ldr d0, [x22, #-0x8]\n" - "ldr d26, [x21, #-0x8]\n" - "fcvtl v0.4s, v0.4h\n" - "fmla v28.4s, v18.4s, v20.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x23, #0x20]\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x23, #0x40]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q19, [x23, #0x60]\n" - ".inst 0x4e9da674 // smmla v20.4s, v19.16b, v29.16b\n" - ".inst 0x4e83a672 // smmla v18.4s, v19.16b, v3.16b\n" - "uzp1 v19.2d, v20.2d, v18.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp2 v20.2d, v20.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v9.s[0]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v11.4s, v19.4s, v18.4s\n" - "ldr q18, [x22, #0x0]\n" - "fmul v19.4s, v27.4s, v9.s[1]\n" - "fmla v13.4s, v20.4s, v19.4s\n" - "movi v19.4s, #0x0\n" - "movi v20.4s, #0x0\n" - ".inst 0x4e88a633 // smmla v19.4s, v17.16b, v8.16b\n" - ".inst 0x4e9fa634 // smmla v20.4s, v17.16b, v31.16b\n" - "ldr q17, [x23, #0x30]\n" - ".inst 0x4e8fa633 // smmla v19.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a634 // smmla v20.4s, v17.16b, v1.16b\n" - "ldr q17, [x23, #0x50]\n" - ".inst 0x4e95a633 // smmla v19.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a634 // smmla v20.4s, v17.16b, v16.16b\n" - "ldr q17, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4e9da633 // smmla v19.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a634 // smmla v20.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v9.s[2]\n" - "fmul v9.4s, v27.4s, v9.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v22.4s, v17.4s, v19.4s\n" - "ldr q17, [x22, #0x10]\n" - "movi v19.4s, #0x0\n" - ".inst 0x4e88a653 // smmla v19.4s, v18.16b, v8.16b\n" - "fmla v23.4s, v20.4s, v9.4s\n" - "movi v20.4s, #0x0\n" - "movi v9.4s, #0x0\n" - ".inst 0x4e9fa654 // smmla v20.4s, v18.16b, v31.16b\n" - "ldr q18, [x22, #0x20]\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - ".inst 0x4e8fa653 // smmla v19.4s, v18.16b, v15.16b\n" - ".inst 0x4e81a654 // smmla v20.4s, v18.16b, v1.16b\n" - "ldr q18, [x22, #0x40]\n" - ".inst 0x4e95a653 // smmla v19.4s, v18.16b, v21.16b\n" - ".inst 0x4e90a654 // smmla v20.4s, v18.16b, v16.16b\n" - "ldr q18, [x22, #0x60]\n" - ".inst 0x4e9da653 // smmla v19.4s, v18.16b, v29.16b\n" - ".inst 0x4e83a654 // smmla v20.4s, v18.16b, v3.16b\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e9fa632 // smmla v18.4s, v17.16b, v31.16b\n" - "ldr q17, [x22, #0x30]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - ".inst 0x4e81a632 // smmla v18.4s, v17.16b, v1.16b\n" - "ldr q17, [x22, #0x50]\n" - ".inst 0x4e95a629 // smmla v9.4s, v17.16b, v21.16b\n" - ".inst 0x4e90a632 // smmla v18.4s, v17.16b, v16.16b\n" - "ldr q17, [x22, #0x70]\n" - "add x22, x22, #0x88\n" - ".inst 0x4e9da629 // smmla v9.4s, v17.16b, v29.16b\n" - ".inst 0x4e83a632 // smmla v18.4s, v17.16b, v3.16b\n" - "uzp1 v17.2d, v19.2d, v20.2d\n" - "uzp2 v20.2d, v19.2d, v20.2d\n" - "fmul v19.4s, v27.4s, v0.s[0]\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v25.4s, v17.4s, v19.4s\n" - "ldr q19, [x21, #0x0]\n" - "fmul v17.4s, v27.4s, v0.s[1]\n" - "fmla v5.4s, v20.4s, v17.4s\n" - "ldr q17, [x21, #0x10]\n" - "uzp1 v20.2d, v9.2d, v18.2d\n" - "uzp2 v9.2d, v9.2d, v18.2d\n" - "fmul v18.4s, v27.4s, v0.s[2]\n" - "fmul v0.4s, v27.4s, v0.s[3]\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "fmla v7.4s, v20.4s, v18.4s\n" - "movi v20.4s, #0x0\n" - "movi v18.4s, #0x0\n" - ".inst 0x4e88a674 // smmla v20.4s, v19.16b, v8.16b\n" - ".inst 0x4e9fa672 // smmla v18.4s, v19.16b, v31.16b\n" - "ldr q19, [x21, #0x20]\n" - "fmla v4.4s, v9.4s, v0.4s\n" - "movi v9.4s, #0x0\n" - "movi v0.4s, #0x0\n" - ".inst 0x4e88a629 // smmla v9.4s, v17.16b, v8.16b\n" - "fmul v8.4s, v27.4s, v26.s[0]\n" - ".inst 0x4e9fa620 // smmla v0.4s, v17.16b, v31.16b\n" - "ldr q17, [x21, #0x30]\n" - ".inst 0x4e8fa674 // smmla v20.4s, v19.16b, v15.16b\n" - "fmul v31.4s, v27.4s, v26.s[1]\n" - ".inst 0x4e81a672 // smmla v18.4s, v19.16b, v1.16b\n" - "ldr q19, [x21, #0x40]\n" - ".inst 0x4e8fa629 // smmla v9.4s, v17.16b, v15.16b\n" - "fmul v15.4s, v27.4s, v26.s[2]\n" - "fmul v27.4s, v27.4s, v26.s[3]\n" - ".inst 0x4e81a620 // smmla v0.4s, v17.16b, v1.16b\n" - "ldr q1, [x21, #0x50]\n" - ".inst 0x4e95a674 // smmla v20.4s, v19.16b, v21.16b\n" - ".inst 0x4e90a672 // smmla v18.4s, v19.16b, v16.16b\n" - "ldr q26, [x21, #0x60]\n" - ".inst 0x4e95a429 // smmla v9.4s, v1.16b, v21.16b\n" - ".inst 0x4e90a420 // smmla v0.4s, v1.16b, v16.16b\n" - "ldr q21, [x21, #0x70]\n" - "add x21, x21, #0x88\n" - ".inst 0x4e9da754 // smmla v20.4s, v26.16b, v29.16b\n" - ".inst 0x4e83a752 // smmla v18.4s, v26.16b, v3.16b\n" - ".inst 0x4e9da6a9 // smmla v9.4s, v21.16b, v29.16b\n" - ".inst 0x4e83a6a0 // smmla v0.4s, v21.16b, v3.16b\n" - "uzp1 v29.2d, v20.2d, v18.2d\n" - "uzp2 v21.2d, v20.2d, v18.2d\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "uzp1 v18.2d, v9.2d, v0.2d\n" - "uzp2 v16.2d, v9.2d, v0.2d\n" - "scvtf v21.4s, v21.4s, #0x4\n" - "fmla v6.4s, v29.4s, v8.4s\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v30.4s, v21.4s, v31.4s\n" - "fmla v24.4s, v18.4s, v15.4s\n" - "fmla v14.4s, v16.4s, v27.4s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x27, x27, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q28, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q11, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q13, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q22, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q23, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q25, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q5, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q7, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q4, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q6, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q30, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q24, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "str q14, [x20, #0x0]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x10, x10, #0x10\n" - "cmp x10, #0x10\n" - "mov %x[res_ptr], x26\n" - "madd %x[a_ptr], x20, x9, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x10, 9f\n" - "5:" // Row tail: Row loop - "add x24, %x[b_ptr], #0x8\n" - "mov x23, %x[nc]\n" - "add x22, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "movi v2.16b, #0x0\n" - "movi v10.16b, #0x0\n" - "add x25, %x[a_ptr], #0x8\n" - "mov x21, %x[nb]\n" - "movi v12.16b, #0x0\n" - "movi v28.16b, #0x0\n" - "7:" // Row tail: Block loop - "ldr q6, [x24, #0x0]\n" - "ldr q5, [x24, #0x10]\n" - "movi v17.16b, #0x4\n" - "movi v8.4s, #0x0\n" - "ldr q4, [x25, #0x0]\n" - "ldr q13, [x25, #0x10]\n" - "movi v27.4s, #0x0\n" - "movi v0.4s, #0x0\n" - "ldr q31, [x24, #0x20]\n" - "ldr q14, [x24, #0x30]\n" - "movi v29.4s, #0x0\n" - "movi v22.16b, #0xf0\n" - "ldr q11, [x25, #0x20]\n" - "ldr q23, [x25, #0x30]\n" - "sshl v21.16b, v6.16b, v17.16b\n" - "sshl v16.16b, v5.16b, v17.16b\n" - "ldr q20, [x25, #0x40]\n" - "ldr q26, [x25, #0x50]\n" - "and v6.16b, v6.16b, v22.16b\n" - "and v5.16b, v5.16b, v22.16b\n" - "ldr q25, [x25, #0x60]\n" - "ldr q3, [x25, #0x70]\n" - "sshl v19.16b, v31.16b, v17.16b\n" - "sshl v18.16b, v14.16b, v17.16b\n" - "ldr d17, [x25, #-0x8]\n" - ".inst 0x4e95a488 // smmla v8.4s, v4.16b, v21.16b\n" - ".inst 0x4e90a49b // smmla v27.4s, v4.16b, v16.16b\n" - "and v31.16b, v31.16b, v22.16b\n" - ".inst 0x4e95a5a0 // smmla v0.4s, v13.16b, v21.16b\n" - ".inst 0x4e90a5bd // smmla v29.4s, v13.16b, v16.16b\n" - "and v14.16b, v14.16b, v22.16b\n" - "sub x20, x24, #0x8\n" - "ldr d16, [x20, #0x0]\n" - "subs x21, x21, #0x1\n" - "add x25, x25, #0x88\n" - "fcvtl v17.4s, v17.4h\n" - "add x24, x24, #0x48\n" - ".inst 0x4e93a568 // smmla v8.4s, v11.16b, v19.16b\n" - ".inst 0x4e92a57b // smmla v27.4s, v11.16b, v18.16b\n" - ".inst 0x4e93a6e0 // smmla v0.4s, v23.16b, v19.16b\n" - ".inst 0x4e92a6fd // smmla v29.4s, v23.16b, v18.16b\n" - "fcvtl v16.4s, v16.4h\n" - ".inst 0x4e86a688 // smmla v8.4s, v20.16b, v6.16b\n" - ".inst 0x4e85a69b // smmla v27.4s, v20.16b, v5.16b\n" - "fmul v23.4s, v16.4s, v17.s[0]\n" - "fmul v21.4s, v16.4s, v17.s[1]\n" - "fmul v1.4s, v16.4s, v17.s[2]\n" - "fmul v20.4s, v16.4s, v17.s[3]\n" - ".inst 0x4e86a740 // smmla v0.4s, v26.16b, v6.16b\n" - ".inst 0x4e85a75d // smmla v29.4s, v26.16b, v5.16b\n" - ".inst 0x4e9fa728 // smmla v8.4s, v25.16b, v31.16b\n" - ".inst 0x4e8ea73b // smmla v27.4s, v25.16b, v14.16b\n" - ".inst 0x4e9fa460 // smmla v0.4s, v3.16b, v31.16b\n" - ".inst 0x4e8ea47d // smmla v29.4s, v3.16b, v14.16b\n" - "uzp1 v19.2d, v8.2d, v27.2d\n" - "uzp2 v18.2d, v8.2d, v27.2d\n" - "scvtf v19.4s, v19.4s, #0x4\n" - "uzp1 v17.2d, v0.2d, v29.2d\n" - "uzp2 v16.2d, v0.2d, v29.2d\n" - "scvtf v18.4s, v18.4s, #0x4\n" - "fmla v2.4s, v19.4s, v23.4s\n" - "scvtf v17.4s, v17.4s, #0x4\n" - "scvtf v16.4s, v16.4s, #0x4\n" - "fmla v10.4s, v18.4s, v21.4s\n" - "fmla v12.4s, v17.4s, v1.4s\n" - "fmla v28.4s, v16.4s, v20.4s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x10, #0x1\n" - "str q2, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x2\n" - "str q10, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x10, #0x3\n" - "str q12, [x20, #0x0]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "str q28, [x20, #0x0]\n" - "8:" // Row tail: Accumulator store skip - "subs x23, x23, #0x4\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "bne 6b\n" - "subs x10, x10, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x9\n" - "mov %x[res_ptr], x22\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28" - ); - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } -} - -static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 8; @@ -2653,420 +1009,7 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c UNUSED(ncols_interleaved); UNUSED(blocklen); -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) -#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) - if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); - - __asm__ __volatile__( - "mov x20, #0x4\n" - "mov x13, %x[nr]\n" - "mov z28.s, #-0x4\n" - "mov x12, #0x88\n" - "ptrue p1.b\n" - "whilelt p0.s, XZR, x20\n" - "cmp x13, #0x10\n" - "mul x12, %x[nb], x12\n" - "blt 4f\n" - "1:" // Row loop - "add x11, %x[b_ptr], #0x10\n" - "mov x10, %x[nc]\n" - "add x9, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x28, %x[a_ptr], #0x8\n" - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "mov x27, %x[nb]\n" - "add x26, x28, x12\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "add x25, x26, x12\n" - "mov z13.b, #0x0\n" - "mov z1.b, #0x0\n" - "add x24, x25, x12\n" - "mov z20.b, #0x0\n" - "mov z25.b, #0x0\n" - "mov z11.b, #0x0\n" - "mov z16.b, #0x0\n" - "mov z19.b, #0x0\n" - "mov z26.b, #0x0\n" - "mov z8.b, #0x0\n" - "mov z29.b, #0x0\n" - "mov z27.b, #0x0\n" - "mov z10.b, #0x0\n" - "3:" // Block loop - "ld1b { z30.b }, p1/Z, [x11]\n" - "ld1b { z21.b }, p1/Z, [x11, #1, MUL VL]\n" - "mov z18.s, #0x0\n" - "mov z7.s, #0x0\n" - "ld1rqb { z3.b }, p1/Z, [x28]\n" - "ld1rqb { z5.b }, p1/Z, [x28, #16]\n" - "mov z9.s, #0x0\n" - "mov z22.s, #0x0\n" - "ld1b { z4.b }, p1/Z, [x11, #2, MUL VL]\n" - "ld1b { z17.b }, p1/Z, [x11, #3, MUL VL]\n" - "sub x20, x11, #0x10\n" - "sub x23, x28, #0x8\n" - "lsl z31.b, z30.b, #0x4\n" - "lsl z6.b, z21.b, #0x4\n" - "ld1h { z23.s }, p1/Z, [x20]\n" - "sub x22, x26, #0x8\n" - "and z30.b, z30.b, #0xf0\n" - "and z21.b, z21.b, #0xf0\n" - "sub x21, x25, #0x8\n" - "sub x20, x24, #0x8\n" - "lsl z14.b, z4.b, #0x4\n" - "lsl z2.b, z17.b, #0x4\n" - "subs x27, x27, #0x1\n" - "add x11, x11, #0x90\n" - ".inst 0x451f9872 // smmla z18.s, z3.b, z31.b\n" - ".inst 0x45069867 // smmla z7.s, z3.b, z6.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #32]\n" - "and z4.b, z4.b, #0xf0\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #48]\n" - "and z17.b, z17.b, #0xf0\n" - "fcvt z23.s, p1/m, z23.h\n" - ".inst 0x450e9872 // smmla z18.s, z3.b, z14.b\n" - ".inst 0x45029867 // smmla z7.s, z3.b, z2.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #64]\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #80]\n" - "fscale z23.s, p1/m, z23.s, z28.s\n" - ".inst 0x451e9872 // smmla z18.s, z3.b, z30.b\n" - ".inst 0x45159867 // smmla z7.s, z3.b, z21.b\n" - "ld1rqb { z3.b }, p1/Z, [x28, #96]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x28, #112]\n" - "add x28, x28, #0x88\n" - ".inst 0x45049872 // smmla z18.s, z3.b, z4.b\n" - ".inst 0x45119867 // smmla z7.s, z3.b, z17.b\n" - "ld1h { z3.s }, p0/Z, [x23]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "fcvt z3.s, p1/m, z3.h\n" - "uzp1 z5.d, z18.d, z7.d\n" - "uzp2 z18.d, z18.d, z7.d\n" - "mov z3.q, z3.q[0]\n" - "uzp1 z7.d, z9.d, z22.d\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z3.s[0]\n" - "scvtf z5.s, p1/m, z5.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "scvtf z7.s, p1/m, z7.s\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z24.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z5.b }, p1/Z, [x26]\n" - "fmul z9.s, z23.s, z3.s[1]\n" - "fmla z15.s, p1/M, z18.s, z9.s\n" - "ld1rqb { z18.b }, p1/Z, [x26, #16]\n" - "fmul z9.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "fmla z12.s, p1/M, z7.s, z9.s\n" - "mov z9.s, #0x0\n" - "ld1h { z7.s }, p0/Z, [x22]\n" - ".inst 0x451f98a9 // smmla z9.s, z5.b, z31.b\n" - "fmla z0.s, p1/M, z22.s, z3.s\n" - "mov z22.s, #0x0\n" - "ld1h { z3.s }, p0/Z, [x21]\n" - ".inst 0x450698b6 // smmla z22.s, z5.b, z6.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #32]\n" - "fcvt z7.s, p1/m, z7.h\n" - "fcvt z3.s, p1/m, z3.h\n" - ".inst 0x450e98a9 // smmla z9.s, z5.b, z14.b\n" - ".inst 0x450298b6 // smmla z22.s, z5.b, z2.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #64]\n" - "mov z7.q, z7.q[0]\n" - "mov z3.q, z3.q[0]\n" - ".inst 0x451e98a9 // smmla z9.s, z5.b, z30.b\n" - ".inst 0x451598b6 // smmla z22.s, z5.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x26, #96]\n" - ".inst 0x450498a9 // smmla z9.s, z5.b, z4.b\n" - ".inst 0x451198b6 // smmla z22.s, z5.b, z17.b\n" - "uzp1 z5.d, z9.d, z22.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "uzp2 z22.d, z9.d, z22.d\n" - "fmul z9.s, z23.s, z7.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z13.s, p1/M, z5.s, z9.s\n" - "ld1rqb { z9.b }, p1/Z, [x25]\n" - "fmul z5.s, z23.s, z7.s[1]\n" - "fmla z1.s, p1/M, z22.s, z5.s\n" - "mov z5.s, #0x0\n" - "mov z22.s, #0x0\n" - ".inst 0x451f9a45 // smmla z5.s, z18.b, z31.b\n" - ".inst 0x45069a56 // smmla z22.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #48]\n" - ".inst 0x450e9a45 // smmla z5.s, z18.b, z14.b\n" - ".inst 0x45029a56 // smmla z22.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #80]\n" - ".inst 0x451e9a45 // smmla z5.s, z18.b, z30.b\n" - ".inst 0x45159a56 // smmla z22.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x26, #112]\n" - "add x26, x26, #0x88\n" - ".inst 0x45049a45 // smmla z5.s, z18.b, z4.b\n" - ".inst 0x45119a56 // smmla z22.s, z18.b, z17.b\n" - "uzp1 z18.d, z5.d, z22.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z22.d, z5.d, z22.d\n" - "fmul z5.s, z23.s, z7.s[2]\n" - "fmul z7.s, z23.s, z7.s[3]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z20.s, p1/M, z18.s, z5.s\n" - "ld1rqb { z18.b }, p1/Z, [x25, #16]\n" - "ld1h { z5.s }, p0/Z, [x20]\n" - "fcvt z5.s, p1/m, z5.h\n" - "fmla z25.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9936 // smmla z22.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #32]\n" - "mov z5.q, z5.q[0]\n" - ".inst 0x450e9936 // smmla z22.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #64]\n" - ".inst 0x451e9936 // smmla z22.s, z9.b, z30.b\n" - ".inst 0x45159927 // smmla z7.s, z9.b, z21.b\n" - "ld1rqb { z9.b }, p1/Z, [x25, #96]\n" - ".inst 0x45049936 // smmla z22.s, z9.b, z4.b\n" - ".inst 0x45119927 // smmla z7.s, z9.b, z17.b\n" - "uzp1 z9.d, z22.d, z7.d\n" - "scvtf z9.s, p1/m, z9.s\n" - "uzp2 z22.d, z22.d, z7.d\n" - "fmul z7.s, z23.s, z3.s[0]\n" - "scvtf z22.s, p1/m, z22.s\n" - "fmla z11.s, p1/M, z9.s, z7.s\n" - "ld1rqb { z9.b }, p1/Z, [x24]\n" - "fmul z7.s, z23.s, z3.s[1]\n" - "fmla z16.s, p1/M, z22.s, z7.s\n" - "mov z22.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9a56 // smmla z22.s, z18.b, z31.b\n" - ".inst 0x45069a47 // smmla z7.s, z18.b, z6.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #48]\n" - ".inst 0x450e9a56 // smmla z22.s, z18.b, z14.b\n" - ".inst 0x45029a47 // smmla z7.s, z18.b, z2.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #80]\n" - ".inst 0x451e9a56 // smmla z22.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x25, #112]\n" - "add x25, x25, #0x88\n" - ".inst 0x45049a56 // smmla z22.s, z18.b, z4.b\n" - ".inst 0x45119a47 // smmla z7.s, z18.b, z17.b\n" - "uzp1 z18.d, z22.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp2 z7.d, z22.d, z7.d\n" - "fmul z22.s, z23.s, z3.s[2]\n" - "fmul z3.s, z23.s, z3.s[3]\n" - "scvtf z7.s, p1/m, z7.s\n" - "fmla z19.s, p1/M, z18.s, z22.s\n" - "ld1rqb { z18.b }, p1/Z, [x24, #16]\n" - "fmul z22.s, z23.s, z5.s[0]\n" - "fmla z26.s, p1/M, z7.s, z3.s\n" - "mov z3.s, #0x0\n" - "mov z7.s, #0x0\n" - ".inst 0x451f9923 // smmla z3.s, z9.b, z31.b\n" - ".inst 0x45069927 // smmla z7.s, z9.b, z6.b\n" - "ld1rqb { z9.b }, p1/Z, [x24, #32]\n" - ".inst 0x450e9923 // smmla z3.s, z9.b, z14.b\n" - ".inst 0x45029927 // smmla z7.s, z9.b, z2.b\n" - "mov z9.s, #0x0\n" - ".inst 0x451f9a49 // smmla z9.s, z18.b, z31.b\n" - "mov z31.s, #0x0\n" - ".inst 0x45069a5f // smmla z31.s, z18.b, z6.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #48]\n" - "ld1rqb { z18.b }, p1/Z, [x24, #64]\n" - ".inst 0x450e98c9 // smmla z9.s, z6.b, z14.b\n" - "fmul z14.s, z23.s, z5.s[1]\n" - ".inst 0x450298df // smmla z31.s, z6.b, z2.b\n" - "ld1rqb { z6.b }, p1/Z, [x24, #80]\n" - "fmul z2.s, z23.s, z5.s[2]\n" - "fmul z23.s, z23.s, z5.s[3]\n" - ".inst 0x451e9a43 // smmla z3.s, z18.b, z30.b\n" - ".inst 0x45159a47 // smmla z7.s, z18.b, z21.b\n" - "ld1rqb { z5.b }, p1/Z, [x24, #96]\n" - ".inst 0x451e98c9 // smmla z9.s, z6.b, z30.b\n" - ".inst 0x451598df // smmla z31.s, z6.b, z21.b\n" - "ld1rqb { z18.b }, p1/Z, [x24, #112]\n" - "add x24, x24, #0x88\n" - ".inst 0x450498a3 // smmla z3.s, z5.b, z4.b\n" - ".inst 0x451198a7 // smmla z7.s, z5.b, z17.b\n" - ".inst 0x45049a49 // smmla z9.s, z18.b, z4.b\n" - ".inst 0x45119a5f // smmla z31.s, z18.b, z17.b\n" - "uzp1 z18.d, z3.d, z7.d\n" - "uzp2 z5.d, z3.d, z7.d\n" - "scvtf z18.s, p1/m, z18.s\n" - "uzp1 z6.d, z9.d, z31.d\n" - "uzp2 z9.d, z9.d, z31.d\n" - "scvtf z5.s, p1/m, z5.s\n" - "fmla z8.s, p1/M, z18.s, z22.s\n" - "scvtf z6.s, p1/m, z6.s\n" - "scvtf z9.s, p1/m, z9.s\n" - "fmla z29.s, p1/M, z5.s, z14.s\n" - "fmla z27.s, p1/M, z6.s, z2.s\n" - "fmla z10.s, p1/M, z9.s, z23.s\n" - "bgt 3b\n" - "mov x20, %x[res_ptr]\n" - "subs x10, x10, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z0.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z13.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z1.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z20.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z25.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z11.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z16.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z19.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z26.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z8.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z29.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z27.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "st1w { z10.s }, p1, [x20]\n" - "bne 2b\n" - "mov x20, #0x4\n" - "sub x13, x13, #0x10\n" - "cmp x13, #0x10\n" - "mov %x[res_ptr], x9\n" - "madd %x[a_ptr], x20, x12, %x[a_ptr]\n" - "bge 1b\n" - "4:" // Row loop skip - "cbz x13, 9f\n" - "5:" // Row tail: Row loop - "add x25, %x[b_ptr], #0x10\n" - "mov x24, %x[nc]\n" - "add x23, %x[res_ptr], %x[res_stride], LSL #2\n" - "6:" // Row tail: Column loop - "mov z24.b, #0x0\n" - "mov z15.b, #0x0\n" - "add x28, %x[a_ptr], #0x8\n" - "mov x22, %x[nb]\n" - "mov z12.b, #0x0\n" - "mov z0.b, #0x0\n" - "7:" // Row tail: Block loop - "ld1b { z3.b }, p1/Z, [x25]\n" - "ld1b { z6.b }, p1/Z, [x25, #1, MUL VL]\n" - "mov z2.s, #0x0\n" - "mov z25.s, #0x0\n" - "ld1rqb { z26.b }, p1/Z, [x28]\n" - "ld1rqb { z21.b }, p1/Z, [x28, #16]\n" - "mov z27.s, #0x0\n" - "mov z19.s, #0x0\n" - "ld1b { z29.b }, p1/Z, [x25, #2, MUL VL]\n" - "ld1b { z16.b }, p1/Z, [x25, #3, MUL VL]\n" - "sub x21, x25, #0x10\n" - "sub x20, x28, #0x8\n" - "lsl z20.b, z3.b, #0x4\n" - "lsl z4.b, z6.b, #0x4\n" - "ld1rqb { z10.b }, p1/Z, [x28, #32]\n" - "ld1rqb { z23.b }, p1/Z, [x28, #48]\n" - "and z3.b, z3.b, #0xf0\n" - "and z6.b, z6.b, #0xf0\n" - "ld1rqb { z11.b }, p1/Z, [x28, #64]\n" - "ld1rqb { z7.b }, p1/Z, [x28, #80]\n" - "lsl z8.b, z29.b, #0x4\n" - "lsl z14.b, z16.b, #0x4\n" - "ld1rqb { z18.b }, p1/Z, [x28, #96]\n" - "ld1rqb { z30.b }, p1/Z, [x28, #112]\n" - ".inst 0x45149b42 // smmla z2.s, z26.b, z20.b\n" - ".inst 0x45049b59 // smmla z25.s, z26.b, z4.b\n" - "and z29.b, z29.b, #0xf0\n" - "ld1h { z17.s }, p1/Z, [x21]\n" - ".inst 0x45149abb // smmla z27.s, z21.b, z20.b\n" - ".inst 0x45049ab3 // smmla z19.s, z21.b, z4.b\n" - "and z16.b, z16.b, #0xf0\n" - "ld1h { z4.s }, p0/Z, [x20]\n" - "subs x22, x22, #0x1\n" - "add x28, x28, #0x88\n" - "fcvt z17.s, p1/m, z17.h\n" - "add x25, x25, #0x90\n" - ".inst 0x45089942 // smmla z2.s, z10.b, z8.b\n" - ".inst 0x450e9959 // smmla z25.s, z10.b, z14.b\n" - "fcvt z4.s, p1/m, z4.h\n" - ".inst 0x45089afb // smmla z27.s, z23.b, z8.b\n" - ".inst 0x450e9af3 // smmla z19.s, z23.b, z14.b\n" - "fscale z17.s, p1/m, z17.s, z28.s\n" - "mov z4.q, z4.q[0]\n" - ".inst 0x45039962 // smmla z2.s, z11.b, z3.b\n" - ".inst 0x45069979 // smmla z25.s, z11.b, z6.b\n" - "fmul z23.s, z17.s, z4.s[0]\n" - "fmul z9.s, z17.s, z4.s[1]\n" - "fmul z21.s, z17.s, z4.s[2]\n" - "fmul z4.s, z17.s, z4.s[3]\n" - ".inst 0x450398fb // smmla z27.s, z7.b, z3.b\n" - ".inst 0x450698f3 // smmla z19.s, z7.b, z6.b\n" - ".inst 0x451d9a42 // smmla z2.s, z18.b, z29.b\n" - ".inst 0x45109a59 // smmla z25.s, z18.b, z16.b\n" - ".inst 0x451d9bdb // smmla z27.s, z30.b, z29.b\n" - ".inst 0x45109bd3 // smmla z19.s, z30.b, z16.b\n" - "uzp1 z31.d, z2.d, z25.d\n" - "uzp2 z13.d, z2.d, z25.d\n" - "scvtf z31.s, p1/m, z31.s\n" - "uzp1 z17.d, z27.d, z19.d\n" - "uzp2 z18.d, z27.d, z19.d\n" - "scvtf z13.s, p1/m, z13.s\n" - "fmla z24.s, p1/M, z31.s, z23.s\n" - "scvtf z17.s, p1/m, z17.s\n" - "scvtf z18.s, p1/m, z18.s\n" - "fmla z15.s, p1/M, z13.s, z9.s\n" - "fmla z12.s, p1/M, z17.s, z21.s\n" - "fmla z0.s, p1/M, z18.s, z4.s\n" - "bgt 7b\n" - "mov x20, %x[res_ptr]\n" - "cmp x13, #0x1\n" - "st1w { z24.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x2\n" - "st1w { z15.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "cmp x13, #0x3\n" - "st1w { z12.s }, p1, [x20]\n" - "add x20, x20, %x[res_stride]\n" - "ble 8f\n" - "st1w { z0.s }, p1, [x20]\n" - "8:" // Row tail: Accumulator store skip - "subs x24, x24, #0x8\n" - "add %x[res_ptr], %x[res_ptr], #0x20\n" - "bne 6b\n" - "subs x13, x13, #0x4\n" - "add %x[a_ptr], %x[a_ptr], x12\n" - "mov %x[res_ptr], x23\n" - "bgt 5b\n" - "9:" // Row tail: Row loop skip - : [a_ptr] "+&r" (a_ptr), [res_ptr] "+&r" (res_ptr) - : [b_ptr] "r" (b_ptr), [nr] "r" (nr), [nb] "r" (nb), [res_stride] "r" (res_stride), [nc] "r" (nc) - : "cc", "memory", "p0", "p1", "x9", "x10", "x11", "x12", "x13", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28", "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31" - ); - return; - } -#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) -#elif defined(__AVX2__) || defined(__AVX512F__) +#if defined(__AVX2__) || defined(__AVX512F__) { const block_q4_0x8 * b_ptr_start = (const block_q4_0x8 *)vx; const block_q8_0x4 * a_ptr_start = (const block_q8_0x4 *)vy; @@ -3239,22 +1182,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3430,22 +1366,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m512i iacc_mat_00_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_01_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_10_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1)); - __m512i iacc_mat_11_sp1 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1)); - __m512i iacc_mat_00_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_01_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2)); - __m512i iacc_mat_10_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2)); - __m512i iacc_mat_11_sp2 = - _mm512_add_epi32(_mm512_add_epi32(_mm512_add_epi32(mul_sum_i8_pairs_int32x16(lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), mul_sum_i8_pairs_int32x16(lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2)), mul_sum_i8_pairs_int32x16(lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2)); + const __m512i zero = _mm512_setzero_epi32(); + __m512i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_01_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_01_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_01_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_014589CD_3_sp1), lhs_mat_23_2_sp1, rhs_mat_014589CD_2_sp1), lhs_mat_23_1_sp1, rhs_mat_014589CD_1_sp1), lhs_mat_23_0_sp1, rhs_mat_014589CD_0_sp1); + __m512i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp1, rhs_mat_2367ABEF_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367ABEF_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367ABEF_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367ABEF_0_sp1); + __m512i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_01_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_01_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_01_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_01_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367ABEF_0_sp2); + __m512i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_014589CD_3_sp2), lhs_mat_23_2_sp2, rhs_mat_014589CD_2_sp2), lhs_mat_23_1_sp2, rhs_mat_014589CD_1_sp2), lhs_mat_23_0_sp2, rhs_mat_014589CD_0_sp2); + __m512i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(mul_sum_i8_pairs_acc_int32x16(zero, lhs_mat_23_3_sp2, rhs_mat_2367ABEF_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367ABEF_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367ABEF_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367ABEF_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m512i iacc_mat_00 = _mm512_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3605,22 +1534,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3769,22 +1691,15 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version - __m256i iacc_mat_00_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_01_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_10_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1)); - __m256i iacc_mat_11_sp1 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1)); - __m256i iacc_mat_00_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_01_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2)); - __m256i iacc_mat_10_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2)); - __m256i iacc_mat_11_sp2 = - _mm256_add_epi32(_mm256_add_epi32(_mm256_add_epi32(mul_sum_i8_pairs_int32x8(lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), mul_sum_i8_pairs_int32x8(lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2)), mul_sum_i8_pairs_int32x8(lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2)); + const __m256i zero = _mm256_setzero_si256(); + __m256i iacc_mat_00_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_01_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_01_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_01_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_01_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_01_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_01_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_01_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_10_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_0145_3_sp1), lhs_mat_23_2_sp1, rhs_mat_0145_2_sp1), lhs_mat_23_1_sp1, rhs_mat_0145_1_sp1), lhs_mat_23_0_sp1, rhs_mat_0145_0_sp1); + __m256i iacc_mat_11_sp1 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp1, rhs_mat_2367_3_sp1), lhs_mat_23_2_sp1, rhs_mat_2367_2_sp1), lhs_mat_23_1_sp1, rhs_mat_2367_1_sp1), lhs_mat_23_0_sp1, rhs_mat_2367_0_sp1); + __m256i iacc_mat_00_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_01_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_01_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_01_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_01_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_01_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_01_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_01_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_01_0_sp2, rhs_mat_2367_0_sp2); + __m256i iacc_mat_10_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_0145_3_sp2), lhs_mat_23_2_sp2, rhs_mat_0145_2_sp2), lhs_mat_23_1_sp2, rhs_mat_0145_1_sp2), lhs_mat_23_0_sp2, rhs_mat_0145_0_sp2); + __m256i iacc_mat_11_sp2 = mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(mul_sum_i8_pairs_acc_int32x8(zero, lhs_mat_23_3_sp2, rhs_mat_2367_3_sp2), lhs_mat_23_2_sp2, rhs_mat_2367_2_sp2), lhs_mat_23_1_sp2, rhs_mat_2367_1_sp2), lhs_mat_23_0_sp2, rhs_mat_2367_0_sp2); // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block __m256i iacc_mat_00 = _mm256_add_epi32(iacc_mat_00_sp1, iacc_mat_00_sp2); @@ -3817,207 +1732,7 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c } return; } -#elif defined(__riscv_v_intrinsic) - if (__riscv_vlenb() >= QK4_0) { - const size_t vl = QK4_0; - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); - vfloat32m1_t sumf0 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf1 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf2 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - vfloat32m1_t sumf3 = __riscv_vfmv_v_f_f32m1(0.0, vl / 4); - for (int l = 0; l < nb; l++) { - const vint8m4_t rhs_raw_vec = __riscv_vle8_v_i8m4((const int8_t *)b_ptr[l].qs, vl * 4); - const vint8m4_t rhs_vec_lo = __riscv_vsra_vx_i8m4(__riscv_vsll_vx_i8m4(rhs_raw_vec, 4, vl * 4), 4, vl * 4); - const vint8m4_t rhs_vec_hi = __riscv_vsra_vx_i8m4(rhs_raw_vec, 4, vl * 4); - const vint8m2_t rhs_vec_lo_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 0); - const vint8m2_t rhs_vec_lo_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_lo, 1); - const vint8m2_t rhs_vec_hi_0 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 0); - const vint8m2_t rhs_vec_hi_1 = __riscv_vget_v_i8m4_i8m2(rhs_vec_hi, 1); - - // vector version needs Zvfhmin extension - const float a_scales[4] = { - GGML_FP16_TO_FP32(a_ptr[l].d[0]), - GGML_FP16_TO_FP32(a_ptr[l].d[1]), - GGML_FP16_TO_FP32(a_ptr[l].d[2]), - GGML_FP16_TO_FP32(a_ptr[l].d[3]) - }; - const float b_scales[8] = { - GGML_FP16_TO_FP32(b_ptr[l].d[0]), - GGML_FP16_TO_FP32(b_ptr[l].d[1]), - GGML_FP16_TO_FP32(b_ptr[l].d[2]), - GGML_FP16_TO_FP32(b_ptr[l].d[3]), - GGML_FP16_TO_FP32(b_ptr[l].d[4]), - GGML_FP16_TO_FP32(b_ptr[l].d[5]), - GGML_FP16_TO_FP32(b_ptr[l].d[6]), - GGML_FP16_TO_FP32(b_ptr[l].d[7]) - }; - const vfloat32m1_t b_scales_vec = __riscv_vle32_v_f32m1(b_scales, vl / 4); - - const int64_t A0 = *(const int64_t *)&a_ptr[l].qs[0]; - const int64_t A4 = *(const int64_t *)&a_ptr[l].qs[32]; - const int64_t A8 = *(const int64_t *)&a_ptr[l].qs[64]; - const int64_t Ac = *(const int64_t *)&a_ptr[l].qs[96]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l0; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A0, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A4, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A8, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ac, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l0 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l0)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[0], vl / 4); - sumf0 = __riscv_vfmacc_vv_f32m1(sumf0, tmp1, b_scales_vec, vl / 4); - } - - const int64_t A1 = *(const int64_t *)&a_ptr[l].qs[8]; - const int64_t A5 = *(const int64_t *)&a_ptr[l].qs[40]; - const int64_t A9 = *(const int64_t *)&a_ptr[l].qs[72]; - const int64_t Ad = *(const int64_t *)&a_ptr[l].qs[104]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l1; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A1, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A5, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A9, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ad, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l1 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l1)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[1], vl / 4); - sumf1 = __riscv_vfmacc_vv_f32m1(sumf1, tmp1, b_scales_vec, vl / 4); - } - - const int64_t A2 = *(const int64_t *)&a_ptr[l].qs[16]; - const int64_t A6 = *(const int64_t *)&a_ptr[l].qs[48]; - const int64_t Aa = *(const int64_t *)&a_ptr[l].qs[80]; - const int64_t Ae = *(const int64_t *)&a_ptr[l].qs[112]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l2; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A2, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A6, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Aa, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ae, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l2 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l2)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[2], vl / 4); - sumf2 = __riscv_vfmacc_vv_f32m1(sumf2, tmp1, b_scales_vec, vl / 4); - } - - const int64_t A3 = *(const int64_t *)&a_ptr[l].qs[24]; - const int64_t A7 = *(const int64_t *)&a_ptr[l].qs[56]; - const int64_t Ab = *(const int64_t *)&a_ptr[l].qs[88]; - const int64_t Af = *(const int64_t *)&a_ptr[l].qs[120]; - __asm__ __volatile__("" ::: "memory"); // prevent gcc from emitting fused vlse64, violating alignment - vint16m4_t sumi_l3; - { - const vint8m2_t lhs_0_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A3, vl / 4)); - const vint8m2_t lhs_1_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(A7, vl / 4)); - const vint8m2_t lhs_2_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Ab, vl / 4)); - const vint8m2_t lhs_3_8 =__riscv_vreinterpret_v_i64m2_i8m2(__riscv_vmv_v_x_i64m2(Af, vl / 4)); - const vint16m4_t sumi_lo_0 = __riscv_vwmul_vv_i16m4(rhs_vec_lo_0, lhs_0_8, vl * 2); - const vint16m4_t sumi_lo_1 = __riscv_vwmacc_vv_i16m4(sumi_lo_0, rhs_vec_lo_1, lhs_1_8, vl * 2); - const vint16m4_t sumi_hi_0 = __riscv_vwmacc_vv_i16m4(sumi_lo_1, rhs_vec_hi_0, lhs_2_8, vl * 2); - const vint16m4_t sumi_hi_m = __riscv_vwmacc_vv_i16m4(sumi_hi_0, rhs_vec_hi_1, lhs_3_8, vl * 2); - - sumi_l3 = sumi_hi_m; - } - - { - const vuint32m4_t sumi_i32 = __riscv_vreinterpret_v_i32m4_u32m4(__riscv_vreinterpret_v_i16m4_i32m4(sumi_l3)); - const vuint16m2_t sumi_h2_0 = __riscv_vnsrl_wx_u16m2(sumi_i32, 0, vl); - const vuint16m2_t sumi_h2_1 = __riscv_vnsrl_wx_u16m2(sumi_i32, 16, vl); - const vuint16m2_t sumi_h2 = __riscv_vadd_vv_u16m2(sumi_h2_0, sumi_h2_1, vl); - const vuint32m2_t sumi_h2_i32 = __riscv_vreinterpret_v_u16m2_u32m2(sumi_h2); - const vuint16m1_t sumi_h4_0 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 0, vl / 2); - const vuint16m1_t sumi_h4_1 = __riscv_vnsrl_wx_u16m1(sumi_h2_i32, 16, vl / 2); - const vuint16m1_t sumi_h4 = __riscv_vadd_vv_u16m1(sumi_h4_0, sumi_h4_1, vl / 2); - const vuint32m1_t sumi_h4_i32 = __riscv_vreinterpret_v_u16m1_u32m1(sumi_h4); - const vint16mf2_t sumi_h8_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 0, vl / 4)); - const vint16mf2_t sumi_h8_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(sumi_h4_i32, 16, vl / 4)); - const vint32m1_t sumi_h8 = __riscv_vwadd_vv_i32m1(sumi_h8_0, sumi_h8_1, vl / 4); - const vfloat32m1_t facc = __riscv_vfcvt_f_x_v_f32m1(sumi_h8, vl / 4); - - const vfloat32m1_t tmp1 = __riscv_vfmul_vf_f32m1(facc, a_scales[3], vl / 4); - sumf3 = __riscv_vfmacc_vv_f32m1(sumf3, tmp1, b_scales_vec, vl / 4); - } - } - __riscv_vse32_v_f32m1(&s[(y * 4 + 0) * bs + x * ncols_interleaved], sumf0, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 1) * bs + x * ncols_interleaved], sumf1, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 2) * bs + x * ncols_interleaved], sumf2, vl / 4); - __riscv_vse32_v_f32m1(&s[(y * 4 + 3) * bs + x * ncols_interleaved], sumf3, vl / 4); - } - } - - return; - } #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) float sumf[4][8]; int sumi; @@ -4053,7 +1768,7 @@ static void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, c } } -static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; const int ncols_interleaved = 8; @@ -4076,7 +1791,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__AVX2__) +#if defined(__AVX2__) || defined(__AVX512F__) const block_q4_Kx8 * b_ptr_start = (const block_q4_Kx8 * ) vx; const block_q8_Kx4 * a_ptr_start = (const block_q8_Kx4 * ) vy; int64_t b_nb = n / QK_K; @@ -4086,8 +1801,748 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c const __m256i m4b = _mm256_set1_epi8(0x0F); // Permute mask used for easier vector processing at later stages __m256i requiredOrder = _mm256_set_epi32(3, 2, 1, 0, 7, 6, 5, 4); - + int64_t xstart = 0; int anr = nr - nr % 16;; // Used to align nr with boundary of 16 +#ifdef __AVX512F__ + int anc = nc - nc % 16; // Used to align nc with boundary of 16 + // Mask to mask out nibbles from packed bytes expanded to 512 bit length + const __m512i m4bexpanded = _mm512_set1_epi8(0x0F); + //Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation + for (; y < anr / 4; y += 4) { + + const block_q8_Kx4 * a_ptrs[4]; + + a_ptrs[0] = a_ptr_start + (y * nb); + for (int i = 0; i < 3; ++i) { + a_ptrs[i + 1] = a_ptrs[i] + nb; + } + + // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[16]; + for (int i = 0; i < 16; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[16]; + for (int i = 0; i < 16; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + + // For super block + for (int64_t b = 0; b < nb; b++) { + // Scale values - Load the sixteen scale values from two block_q4_kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 64; sb++) { + + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //4-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + + const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23) + const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23) + const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31) + const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31) + + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23) + const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23) + const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31) + const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19) + const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19) + const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27) + const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19) + const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19) + const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27) + const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27) + + // Shuffle pattern two - right side input + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23) + const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23) + const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31) + const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7) + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23) + const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23) + const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31) + const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31) + + uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4]; + + // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together + // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12); + utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4); + const uint32_t uaux_00 = utmp_00[1] & kmask1; + utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4); + utmp_00[2] = uaux_00; + utmp_00[0] &= kmask1; + + // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12); + utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4); + const uint32_t uaux_01 = utmp_01[1] & kmask1; + utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4); + utmp_01[2] = uaux_01; + utmp_01[0] &= kmask1; + + memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12); + utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4); + const uint32_t uaux_10 = utmp_10[1] & kmask1; + utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4); + utmp_10[2] = uaux_10; + utmp_10[0] &= kmask1; + + // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12); + utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4); + const uint32_t uaux_11 = utmp_11[1] & kmask1; + utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4); + utmp_11[2] = uaux_11; + utmp_11[0] &= kmask1; + + // Scales of first sub block in the sb loop + const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]); + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0)); + + // Scales of second sub block in the sb loop + const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1)); + + // Mins of first and second sub block of Q4_K block are arranged side by side + const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78))); + + const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + for (int rp = 0; rp < 4; rp++) { + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into 512 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 256 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 32 + 256 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 64 + 256 * sb))); + __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0); + __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17); + __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 96 + 256 * sb))); + __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0); + __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 128 + 256 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 160 + 256 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 192 + 256 * sb))); + __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0); + __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17); + __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].qs + 224 + 256 * sb))); + __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0); + __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17); + + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1); + __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1); + __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1); + __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1); + __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1); + __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1); + __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1); + + // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks + __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptrs[rp][b].bsums + 16 * sb))); + __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1))); + lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0); + __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) + const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) + const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) + const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) + const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) + const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) + const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) + + const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) + const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) + const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) + const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) + const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) + const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) + const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1)); + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2)); + __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0); + iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0); + iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0); + iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0); + + iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1); + iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1); + iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1); + iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1); + + // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step) + __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0); + __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0); + __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1); + __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1); + + __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1); + __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1); + __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1); + __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1); + + // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes + const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d); + const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); + const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); + + // Multiply with appropiate scales and accumulate (for both d and dmin) below + acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); + acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); + acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); + acc_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[rp * 4 + 3]); + + __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01); + __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01); + __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01); + __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01); + + acc_min_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[rp * 4]); + acc_min_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[rp * 4 + 1]); + acc_min_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[rp * 4 + 2]); + acc_min_rows[rp * 4 + 3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[rp * 4 + 3]); + } + } + } + // Store the accumulated values + for (int i = 0; i < 16; i++) { + _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); + } + } + } + + for (; y < nr / 4; y++) { + + const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb); + + // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation + for (int64_t x = 0; x < anc / 8; x += 2) { + + const block_q4_Kx8 * b_ptr_0 = b_ptr_start + ((x) * b_nb); + const block_q4_Kx8 * b_ptr_1 = b_ptr_start + ((x + 1) * b_nb); + + // Master FP accumulators + __m512 acc_rows[4]; + for (int i = 0; i < 4; i++) { + acc_rows[i] = _mm512_setzero_ps(); + } + + __m512 acc_min_rows[4]; + for (int i = 0; i < 4; i++) { + acc_min_rows[i] = _mm512_setzero_ps(); + } + + // For super block + for (int64_t b = 0; b < nb; b++) { + // Scale values - Load the sixteen scale values from two block_q4_kx8 structures + const __m512 col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + + // dmin values - Load the sixteen dmin values from two block_q4_kx8 structures + const __m512 col_dmin_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].dmin, b_ptr_1[b].dmin); + + // Loop to iterate over the eight sub blocks of a super block - two sub blocks are processed per iteration + for (int sb = 0; sb < QK_K / 64; sb++) { + + const __m256i rhs_raw_mat_0123_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + sb * 256)); + const __m256i rhs_raw_mat_4567_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_0123_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_4567_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_0123_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_4567_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_0123_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_4567_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_0[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_89AB_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + sb * 256)); + const __m256i rhs_raw_mat_CDEF_0 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 32 + sb * 256)); + const __m256i rhs_raw_mat_89AB_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 64 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_1 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 96 + sb * 256)); + const __m256i rhs_raw_mat_89AB_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 128 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_2 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 160 + sb * 256)); + const __m256i rhs_raw_mat_89AB_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 192 + sb * 256)); + const __m256i rhs_raw_mat_CDEF_3 = _mm256_loadu_si256((const __m256i * )(b_ptr_1[b].qs + 224 + sb * 256)); + + const __m256i rhs_raw_mat_0145_0 = _mm256_blend_epi32(rhs_raw_mat_0123_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_0, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_0, requiredOrder), rhs_raw_mat_4567_0, 240); + const __m256i rhs_raw_mat_0145_1 = _mm256_blend_epi32(rhs_raw_mat_0123_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_1, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_1, requiredOrder), rhs_raw_mat_4567_1, 240); + const __m256i rhs_raw_mat_0145_2 = _mm256_blend_epi32(rhs_raw_mat_0123_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_2, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_2, requiredOrder), rhs_raw_mat_4567_2, 240); + const __m256i rhs_raw_mat_0145_3 = _mm256_blend_epi32(rhs_raw_mat_0123_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_4567_3, requiredOrder), 240); + const __m256i rhs_raw_mat_2367_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_0123_3, requiredOrder), rhs_raw_mat_4567_3, 240); + + const __m256i rhs_raw_mat_89CD_0 = _mm256_blend_epi32(rhs_raw_mat_89AB_0, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_0, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_0 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_0, requiredOrder), rhs_raw_mat_CDEF_0, 240); + const __m256i rhs_raw_mat_89CD_1 = _mm256_blend_epi32(rhs_raw_mat_89AB_1, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_1, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_1 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_1, requiredOrder), rhs_raw_mat_CDEF_1, 240); + const __m256i rhs_raw_mat_89CD_2 = _mm256_blend_epi32(rhs_raw_mat_89AB_2, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_2, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_2 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_2, requiredOrder), rhs_raw_mat_CDEF_2, 240); + const __m256i rhs_raw_mat_89CD_3 = _mm256_blend_epi32(rhs_raw_mat_89AB_3, _mm256_permutevar8x32_epi32(rhs_raw_mat_CDEF_3, requiredOrder), 240); + const __m256i rhs_raw_mat_ABEF_3 = _mm256_blend_epi32(_mm256_permutevar8x32_epi32(rhs_raw_mat_89AB_3, requiredOrder), rhs_raw_mat_CDEF_3, 240); + + const __m512i rhs_raw_mat_014589CD_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_0), rhs_raw_mat_89CD_0, 1); + const __m512i rhs_raw_mat_2367ABEF_0 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_0), rhs_raw_mat_ABEF_0, 1); + const __m512i rhs_raw_mat_014589CD_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_1), rhs_raw_mat_89CD_1, 1); + const __m512i rhs_raw_mat_2367ABEF_1 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_1), rhs_raw_mat_ABEF_1, 1); + + const __m512i rhs_raw_mat_014589CD_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_2), rhs_raw_mat_89CD_2, 1); + const __m512i rhs_raw_mat_2367ABEF_2 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_2), rhs_raw_mat_ABEF_2, 1); + const __m512i rhs_raw_mat_014589CD_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_0145_3), rhs_raw_mat_89CD_3, 1); + const __m512i rhs_raw_mat_2367ABEF_3 = _mm512_inserti32x8(_mm512_castsi256_si512(rhs_raw_mat_2367_3), rhs_raw_mat_ABEF_3, 1); + + //4-bit -> 8-bit + const __m512i rhs_mat_014589CD_00 = _mm512_and_si512(rhs_raw_mat_014589CD_0, m4bexpanded); //B00(0-7) B01(0-7) B04(0-7) B05(0-7) B08(0-7) B09(0-7) B0C(0-7) B0D(0-7) + const __m512i rhs_mat_2367ABEF_00 = _mm512_and_si512(rhs_raw_mat_2367ABEF_0, m4bexpanded); //B02(0-7) B03(0-7) B06(0-7) B07(0-7) B0A(0-7) B0B(0-7) B0E(0-7) B0F(0-7) + const __m512i rhs_mat_014589CD_01 = _mm512_and_si512(rhs_raw_mat_014589CD_1, m4bexpanded); //B00(8-15) B01(8-15) B04(8-15) B05(8-15) B08(8-15) B09(8-15) B0C(8-15) B0D(8-15) + const __m512i rhs_mat_2367ABEF_01 = _mm512_and_si512(rhs_raw_mat_2367ABEF_1, m4bexpanded); //B02(8-15) B03(8-15) B06(8-15) B07(8-15) B0A(8-15) B0B(8-15) B0E(8-15) B0F(8-15) + + const __m512i rhs_mat_014589CD_02 = _mm512_and_si512(rhs_raw_mat_014589CD_2, m4bexpanded); //B00(16-23) B01(16-23) B04(16-23) B05(16-23) B08(16-23) B09(16-23) B0C(16-23) B0D(16-23) + const __m512i rhs_mat_2367ABEF_02 = _mm512_and_si512(rhs_raw_mat_2367ABEF_2, m4bexpanded); //B02(16-23) B03(16-23) B06(16-23) B07(16-23) B0A(16-23) B0B(16-23) B0E(16-23) B0F(16-23) + const __m512i rhs_mat_014589CD_03 = _mm512_and_si512(rhs_raw_mat_014589CD_3, m4bexpanded); //B00(24-31) B01(24-31) B04(24-31) B05(24-31) B08(24-31) B09(24-31) B0C(24-31) B0D(24-31) + const __m512i rhs_mat_2367ABEF_03 = _mm512_and_si512(rhs_raw_mat_2367ABEF_3, m4bexpanded); //B02(24-31) B03(24-31) B06(24-31) B07(24-31) B0A(24-31) B0B(24-31) B0E(24-31) B0F(24-31) + + const __m512i rhs_mat_014589CD_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_0, 4), m4bexpanded); //B10(0-7) B11(0-7) B14(0-7) B15(0-7) B18(0-7) B19(0-7) B1C(0-7) B1D(0-7) + const __m512i rhs_mat_2367ABEF_10 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_0, 4), m4bexpanded); //B12(0-7) B13(0-7) B16(0-7) B17(0-7) B1A(0-7) B1B(0-7) B1E(0-7) B1F(0-7) + const __m512i rhs_mat_014589CD_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_1, 4), m4bexpanded); //B10(8-15) B11(8-15) B14(8-15) B15(8-15) B18(8-15) B19(8-15) B1C(8-15) B1D(8-15) + const __m512i rhs_mat_2367ABEF_11 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded); //B12(8-15) B13(8-15) B16(8-15) B17(8-15) B1A(8-15) B1B(8-15) B1E(8-15) B1F(8-15) + + const __m512i rhs_mat_014589CD_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_2, 4), m4bexpanded); //B10(16-23) B11(16-23) B14(16-23) B15(16-23) B18(16-23) B19(16-23) B1C(16-23) B1D(16-23) + const __m512i rhs_mat_2367ABEF_12 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_2, 4), m4bexpanded); //B12(16-23) B13(16-23) B16(16-23) B17(16-23) B1A(16-23) B1B(16-23) B1E(16-23) B1F(16-23) + const __m512i rhs_mat_014589CD_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_014589CD_3, 4), m4bexpanded); //B10(24-31) B11(24-31) B14(24-31) B15(24-31) B18(24-31) B19(24-31) B1C(24-31) B1D(24-31) + const __m512i rhs_mat_2367ABEF_13 = _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_3, 4), m4bexpanded); //B12(24-31) B13(24-31) B16(24-31) B17(24-31) B1A(24-31) B1B(24-31) B1E(24-31) B1F(24-31) + + // Shuffle pattern one - right side input + const __m512i rhs_mat_014589CD_00_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)136); //B00(0-3) B01(0-3) B00(0-3) B01(0-3) B04(0-3) B05(0-3) B04(0-3) B05(0-3) B08(0-3) B09(0-3) B08(0-3) B09(0-3) B0C(0-3) B0D(0-3) B0C(0-3) B0D(0-3) + const __m512i rhs_mat_2367ABEF_00_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)136); //B02(0-3) B03(0-3) B02(0-3) B03(0-3) B06(0-3) B07(0-3) B06(0-3) B07(0-3) B0A(0-3) B0B(0-3) B0A(0-3) B0B(0-3) B0E(0-3) B0F(0-3) B0E(0-3) B0F(0-3) + const __m512i rhs_mat_014589CD_01_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)136); //B00(8-11) B01(8-11) B00(8-11) B01(8-11) B04(8-11) B05(8-11) B04(8-11) B05(8-11) B08(8-11) B09(8-11) B08(8-11) B09(8-11) B0C(8-11) B0D(8-11) B0C(8-11) B0D(8-11) + const __m512i rhs_mat_2367ABEF_01_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)136); //B02(8-11) B03(8-11) B02(8-11) B03(8-11) B06(8-11) B07(8-11) B06(8-11) B07(8-11) B0A(8-11) B0B(8-11) B0A(8-11) B0B(8-11) B0E(8-11) B0F(8-11) B0E(8-11) B0F(8-11) + const __m512i rhs_mat_014589CD_02_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)136); //B00(16-19) B01(16-19) B00(16-19) B01(16-19) B04(16-19) B05(16-19) B04(16-19) B05(16-19) B08(16-19) B09(16-19) B08(16-19) B09(16-19) B0C(16-19) B0D(16-19) B0C(16-19) B0D(16-19) + const __m512i rhs_mat_2367ABEF_02_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)136); //B02(16-19) B03(16-19) B02(16-19) B03(16-19) B06(16-19) B07(16-19) B06(16-19) B07(16-19) B0A(16-19) B0B(16-19) B0A(16-19) B0B(16-19) B0E(16-19) B0F(16-19) B0E(16-19) B0F(16-19) + const __m512i rhs_mat_014589CD_03_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)136); //B00(24-27) B01(24-27) B00(24-27) B01(24-27) B04(24-27) B05(24-27) B04(24-27) B05(24-27) B08(24-27) B09(24-27) B08(24-27) B09(24-27) B0C(24-27) B0D(24-27) B0C(24-27) B0D(24-27) + const __m512i rhs_mat_2367ABEF_03_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)136); //B02(24-27) B03(24-27) B02(24-27) B03(24-27) B06(24-27) B07(24-27) B06(24-27) B07(24-27) B0A(24-27) B0B(24-27) B0A(24-27) B0B(24-27) B0E(24-27) B0F(24-27) B0E(24-27) B0F(24-27) + + const __m512i rhs_mat_014589CD_10_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)136); //B10(0-3) B11(0-3) B10(0-3) B11(0-3) B14(0-3) B15(0-3) B14(0-3) B15(0-3) B18(0-3) B19(0-3) B18(0-3) B19(0-3) B1C(0-3) B1D(0-3) B1C(0-3) B1D(0-3) + const __m512i rhs_mat_2367ABEF_10_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)136); //B12(0-3) B13(0-3) B12(0-3) B13(0-3) B16(0-3) B17(0-3) B16(0-3) B17(0-3) B1A(0-3) B1B(0-3) B1A(0-3) B1B(0-3) B1E(0-3) B1F(0-3) B1E(0-3) B1F(0-3) + const __m512i rhs_mat_014589CD_11_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)136); //B10(8-11) B11(8-11) B10(8-11) B11(8-11) B14(8-11) B15(8-11) B14(8-11) B15(8-11) B18(8-11) B19(8-11) B18(8-11) B19(8-11) B1C(8-11) B1D(8-11) B1C(8-11) B1D(8-11) + const __m512i rhs_mat_2367ABEF_11_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)136); //B12(8-11) B13(8-11) B12(8-11) B13(8-11) B16(8-11) B17(8-11) B16(8-11) B17(8-11) B1A(8-11) B1B(8-11) B1A(8-11) B1B(8-11) B1E(8-11) B1F(8-11) B1E(8-11) B1F(8-11) + const __m512i rhs_mat_014589CD_12_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)136); //B10(16-19) B11(16-19) B10(16-19) B11(16-19) B14(16-19) B15(16-19) B14(16-19) B15(16-19) B18(16-19) B19(16-19) B18(16-19) B19(16-19) B1C(16-19) B1D(16-19) B1C(16-19) B1D(16-19) + const __m512i rhs_mat_2367ABEF_12_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)136); //B12(16-19) B13(16-19) B12(16-19) B13(16-19) B16(16-19) B17(16-19) B16(16-19) B17(16-19) B1A(16-19) B1B(16-19) B1A(16-19) B1B(16-19) B1E(16-19) B1F(16-19) B1E(16-19) B1F(16-19) + const __m512i rhs_mat_014589CD_13_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)136); //B10(24-27) B11(24-27) B10(24-27) B11(24-27) B14(24-27) B15(24-27) B14(24-27) B15(24-27) B18(24-27) B19(24-27) B18(24-27) B19(24-27) B1C(24-27) B1D(24-27) B1C(24-27) B1D(24-27) + const __m512i rhs_mat_2367ABEF_13_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)136); //B12(24-27) B13(24-27) B12(24-27) B13(24-27) B16(24-27) B17(24-27) B16(24-27) B17(24-27) B1A(24-27) B1B(24-27) B1A(24-27) B1B(24-27) B1E(24-27) B1F(24-27) B1E(24-27) B1F(24-27) + + // Shuffle pattern two - right side input + const __m512i rhs_mat_014589CD_00_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_00, (_MM_PERM_ENUM)221); //B00(4-7) B01(4-7) B00(4-7) B01(4-7) B04(4-7) B05(4-7) B04(4-7) B05(4-7) B08(4-7) B09(4-7) B08(4-7) B09(4-7) B0C(4-7) B0D(4-7) B0C(4-7) B0D(4-7) + const __m512i rhs_mat_2367ABEF_00_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_00, (_MM_PERM_ENUM)221); //B02(4-7) B03(4-7) B02(4-7) B03(4-7) B06(4-7) B07(4-7) B06(4-7) B07(4-7) B0A(4-7) B0B(4-7) B0A(4-7) B0B(4-7) B0E(4-7) B0F(4-7) B0E(4-7) B0F(4-7) + const __m512i rhs_mat_014589CD_01_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_01, (_MM_PERM_ENUM)221); //B00(12-15) B01(12-15) B00(12-15) B01(12-15) B04(12-15) B05(12-15) B04(12-15) B05(12-15) B08(12-15) B09(12-15) B08(12-15) B09(12-15) B0C(12-15) B0D(12-15) B0C(12-15) B0D(12-15) + const __m512i rhs_mat_2367ABEF_01_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_01, (_MM_PERM_ENUM)221); //B02(12-15) B03(12-15) B02(12-15) B03(12-15) B06(12-15) B07(12-15) B06(12-15) B07(12-15) B0A(12-15) B0B(12-15) B0A(12-15) B0B(12-15) B0E(12-15) B0F(12-15) B0E(12-15) B0F(12-15) + const __m512i rhs_mat_014589CD_02_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_02, (_MM_PERM_ENUM)221); //B00(20-23) B01(20-23) B00(20-23) B01(20-23) B04(20-23) B05(20-23) B04(20-23) B05(20-23) B08(20-23) B09(20-23) B08(20-23) B09(20-23) B0C(20-23) B0D(20-23) B0C(20-23) B0D(20-23) + const __m512i rhs_mat_2367ABEF_02_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_02, (_MM_PERM_ENUM)221); //B02(20-23) B03(20-23) B02(20-23) B03(20-23) B06(20-23) B07(20-23) B06(20-23) B07(20-23) B0A(20-23) B0B(20-23) B0A(20-23) B0B(20-23) B0E(20-23) B0F(20-23) B0E(20-23) B0F(20-23) + const __m512i rhs_mat_014589CD_03_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_03, (_MM_PERM_ENUM)221); //B00(28-31) B01(28-31) B00(28-31) B01(28-31) B04(28-31) B05(28-31) B04(28-31) B05(28-31) B08(28-31) B09(28-31) B08(28-31) B09(28-31) B0C(28-31) B0D(28-31) B0C(28-31) 0BD(28-31) + const __m512i rhs_mat_2367ABEF_03_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_03, (_MM_PERM_ENUM)221); //B02(28-31) B03(28-31) B02(28-31) B03(28-31) B06(28-31) B07(28-31) B06(28-31) B07(28-31) B0A(28-31) B0B(28-31) B0A(28-31) B0B(28-31) B0E(28-31) B0F(28-31) B0E(28-31) B0F(28-31) + + const __m512i rhs_mat_014589CD_10_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_10, (_MM_PERM_ENUM)221); //B10(4-7) B11(4-7) B10(4-7) B11(4-7) B14(4-7) B15(4-7) B14(4-7) B15(4-7) B18(4-7) B19(4-7) B18(4-7) B19(4-7) B1C(4-7) B1D(4-7) B1C(4-7) B1D(4-7) + const __m512i rhs_mat_2367ABEF_10_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_10, (_MM_PERM_ENUM)221); //B12(4-7) B13(4-7) B12(4-7) B13(4-7) B16(4-7) B17(4-7) B16(4-7) B17(4-7) B1A(4-7) B1B(4-7) B1A(4-7) B1B(4-7) B1E(4-7) B1F(4-7) B1E(4-7) B1F(4-7) + const __m512i rhs_mat_014589CD_11_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_11, (_MM_PERM_ENUM)221); //B10(12-15) B11(12-15) B10(12-15) B11(12-15) B14(12-15) B15(12-15) B14(12-15) B15(12-15) B18(12-15) B19(12-15) B18(12-15) B19(12-15) B1C(12-15) B1D(12-15) B1C(12-15) B1D(12-15) + const __m512i rhs_mat_2367ABEF_11_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_11, (_MM_PERM_ENUM)221); //B12(12-15) B13(12-15) B12(12-15) B13(12-15) B16(12-15) B17(12-15) B16(12-15) B17(12-15) B1A(12-15) B1B(12-15) B1A(12-15) B1B(12-15) B1E(12-15) B1F(12-15) B1E(12-15) B1F(12-15) + const __m512i rhs_mat_014589CD_12_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_12, (_MM_PERM_ENUM)221); //B10(20-23) B11(20-23) B10(20-23) B11(20-23) B14(20-23) B15(20-23) B14(20-23) B15(20-23) B18(20-23) B19(20-23) B18(20-23) B19(20-23) B1C(20-23) B1D(20-23) B1C(20-23) B1D(20-23) + const __m512i rhs_mat_2367ABEF_12_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_12, (_MM_PERM_ENUM)221); //B12(20-23) B13(20-23) B12(20-23) B13(20-23) B16(20-23) B17(20-23) B16(20-23) B17(20-23) B1A(20-23) B1B(20-23) B1A(20-23) B1B(20-23) B1E(20-23) B1F(20-23) B1E(20-23) B1F(20-23) + const __m512i rhs_mat_014589CD_13_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_13, (_MM_PERM_ENUM)221); //B10(28-31) B11(28-31) B10(28-31) B11(28-31) B14(28-31) B15(28-31) B14(28-31) B15(28-31) B18(28-31) B19(28-31) B18(28-31) B19(28-31) B1C(28-31) B1D(28-31) B1C(28-31) B1D(28-31) + const __m512i rhs_mat_2367ABEF_13_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_13, (_MM_PERM_ENUM)221); //B12(28-31) B13(28-31) B12(28-31) B13(28-31) B16(28-31) B17(28-31) B16(28-31) B17(28-31) B1A(28-31) B1B(28-31) B1A(28-31) B1B(28-31) B1E(28-31) B1F(28-31) B1E(28-31) B1F(28-31) + + uint32_t utmp_00[4], utmp_01[4], utmp_10[4], utmp_11[4]; + + // Scales and Mins of corresponding sub blocks from different Q4_K structures are stored together + // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_00, b_ptr_0[b].scales + 24 * sb, 12); + utmp_00[3] = ((utmp_00[2] >> 4) & kmask2) | (((utmp_00[1] >> 6) & kmask3) << 4); + const uint32_t uaux_00 = utmp_00[1] & kmask1; + utmp_00[1] = (utmp_00[2] & kmask2) | (((utmp_00[0] >> 6) & kmask3) << 4); + utmp_00[2] = uaux_00; + utmp_00[0] &= kmask1; + + // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_01, b_ptr_0[b].scales + 12 + sb * 24, 12); + utmp_01[3] = ((utmp_01[2] >> 4) & kmask2) | (((utmp_01[1] >> 6) & kmask3) << 4); + const uint32_t uaux_01 = utmp_01[1] & kmask1; + utmp_01[1] = (utmp_01[2] & kmask2) | (((utmp_01[0] >> 6) & kmask3) << 4); + utmp_01[2] = uaux_01; + utmp_01[0] &= kmask1; + + // The below block is for eg to extract first sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_10, b_ptr_1[b].scales + sb * 24, 12); + utmp_10[3] = ((utmp_10[2] >> 4) & kmask2) | (((utmp_10[1] >> 6) & kmask3) << 4); + const uint32_t uaux_10 = utmp_10[1] & kmask1; + utmp_10[1] = (utmp_10[2] & kmask2) | (((utmp_10[0] >> 6) & kmask3) << 4); + utmp_10[2] = uaux_10; + utmp_10[0] &= kmask1; + + // The below block is for eg to extract second sub block's scales and mins from different Q4_K structures for the sb loop + memcpy(utmp_11, b_ptr_1[b].scales + 12 + sb * 24, 12); + utmp_11[3] = ((utmp_11[2] >> 4) & kmask2) | (((utmp_11[1] >> 6) & kmask3) << 4); + const uint32_t uaux_11 = utmp_11[1] & kmask1; + utmp_11[1] = (utmp_11[2] & kmask2) | (((utmp_11[0] >> 6) & kmask3) << 4); + utmp_11[2] = uaux_11; + utmp_11[0] &= kmask1; + + // Scales of first sub block in the sb loop + const __m256i mins_and_scales_0 = _mm256_set_epi32(utmp_10[3], utmp_10[2], utmp_10[1], utmp_10[0], utmp_00[3], utmp_00[2], utmp_00[1], utmp_00[0]); + const __m512i scales_0 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_0, mins_and_scales_0)); + + // Scales of second sub block in the sb loop + const __m256i mins_and_scales_1 = _mm256_set_epi32(utmp_11[3], utmp_11[2], utmp_11[1], utmp_11[0], utmp_01[3], utmp_01[2], utmp_01[1], utmp_01[0]); + const __m512i scales_1 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(mins_and_scales_1, mins_and_scales_1)); + + // Mins of first and second sub block of Q4_K block are arranged side by side + const __m512i mins_01 = _mm512_cvtepu8_epi16(_mm256_unpacklo_epi8(_mm256_shuffle_epi32(mins_and_scales_0, 78), _mm256_shuffle_epi32(mins_and_scales_1, 78))); + + const __m512i scale_014589CD_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_0 = _mm512_shuffle_epi32(scales_0, (_MM_PERM_ENUM)238); + + const __m512i scale_014589CD_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)68); + const __m512i scale_2367ABEF_1 = _mm512_shuffle_epi32(scales_1, (_MM_PERM_ENUM)238); + + // Load the four block_q8_k quantized values interleaved with each other in chunks of eight bytes - A0,A1,A2,A3 + // Loaded as set of 128 bit vectors and repeated into a 256 bit vector + __m256i lhs_mat_ymm_0123_00 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 256 * sb))); + __m256i lhs_mat_ymm_01_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 0); + __m256i lhs_mat_ymm_23_00 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_00, lhs_mat_ymm_0123_00, 17); + __m256i lhs_mat_ymm_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 32 + 256 * sb))); + __m256i lhs_mat_ymm_01_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 0); + __m256i lhs_mat_ymm_23_01 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_01, lhs_mat_ymm_0123_01, 17); + __m256i lhs_mat_ymm_0123_02 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 64 + 256 * sb))); + __m256i lhs_mat_ymm_01_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 0); + __m256i lhs_mat_ymm_23_02 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_02, lhs_mat_ymm_0123_02, 17); + __m256i lhs_mat_ymm_0123_03 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 96 + 256 * sb))); + __m256i lhs_mat_ymm_01_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 0); + __m256i lhs_mat_ymm_23_03 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_03, lhs_mat_ymm_0123_03, 17); + __m256i lhs_mat_ymm_0123_10 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 128 + 256 * sb))); + __m256i lhs_mat_ymm_01_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 0); + __m256i lhs_mat_ymm_23_10 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_10, lhs_mat_ymm_0123_10, 17); + __m256i lhs_mat_ymm_0123_11 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 160 + 256 * sb))); + __m256i lhs_mat_ymm_01_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 0); + __m256i lhs_mat_ymm_23_11 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_11, lhs_mat_ymm_0123_11, 17); + __m256i lhs_mat_ymm_0123_12 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 192 + 256 * sb))); + __m256i lhs_mat_ymm_01_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 0); + __m256i lhs_mat_ymm_23_12 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_12, lhs_mat_ymm_0123_12, 17); + __m256i lhs_mat_ymm_0123_13 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].qs + 224 + 256 * sb))); + __m256i lhs_mat_ymm_01_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 0); + __m256i lhs_mat_ymm_23_13 = _mm256_permute2f128_si256(lhs_mat_ymm_0123_13, lhs_mat_ymm_0123_13, 17); + + //Loaded as set of 128 bit vectors and repeated and stored into a 256 bit vector before again repeating into a 512 bit vector + __m512i lhs_mat_01_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_00), lhs_mat_ymm_01_00, 1); + __m512i lhs_mat_23_00 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_00), lhs_mat_ymm_23_00, 1); + __m512i lhs_mat_01_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_01), lhs_mat_ymm_01_01, 1); + __m512i lhs_mat_23_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_01), lhs_mat_ymm_23_01, 1); + __m512i lhs_mat_01_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_02), lhs_mat_ymm_01_02, 1); + __m512i lhs_mat_23_02 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_02), lhs_mat_ymm_23_02, 1); + __m512i lhs_mat_01_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_03), lhs_mat_ymm_01_03, 1); + __m512i lhs_mat_23_03 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_03), lhs_mat_ymm_23_03, 1); + + __m512i lhs_mat_01_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_10), lhs_mat_ymm_01_10, 1); + __m512i lhs_mat_23_10 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_10), lhs_mat_ymm_23_10, 1); + __m512i lhs_mat_01_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_11), lhs_mat_ymm_01_11, 1); + __m512i lhs_mat_23_11 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_11), lhs_mat_ymm_23_11, 1); + __m512i lhs_mat_01_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_12), lhs_mat_ymm_01_12, 1); + __m512i lhs_mat_23_12 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_12), lhs_mat_ymm_23_12, 1); + __m512i lhs_mat_01_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_01_13), lhs_mat_ymm_01_13, 1); + __m512i lhs_mat_23_13 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_mat_ymm_23_13), lhs_mat_ymm_23_13, 1); + + // Bsums are loaded - four bsums are loaded (for two sub blocks) for the different Q8_K blocks + __m256i lhs_bsums_0123_01 = _mm256_loadu_si256((const __m256i * )((a_ptr[b].bsums + 16 * sb))); + __m256i lhs_bsums_hsum_ymm_0123_01 = _mm256_castsi128_si256(_mm_hadd_epi16(_mm256_castsi256_si128(lhs_bsums_0123_01), _mm256_extractf128_si256(lhs_bsums_0123_01, 1))); + lhs_bsums_hsum_ymm_0123_01 = _mm256_permute2x128_si256(lhs_bsums_hsum_ymm_0123_01, lhs_bsums_hsum_ymm_0123_01, 0); + __m512i lhs_bsums_hsum_0123_01 = _mm512_inserti32x8(_mm512_castsi256_si512(lhs_bsums_hsum_ymm_0123_01), lhs_bsums_hsum_ymm_0123_01, 1); + + // Shuffle pattern one - left side input + const __m512i lhs_mat_01_00_sp1 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)160); //A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) A00(0-3) A00(0-3) A01(0-3) A01(0-3) + const __m512i lhs_mat_23_00_sp1 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)160); //A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) A02(0-3) A02(0-3) A03(0-3) A03(0-3) + const __m512i lhs_mat_01_01_sp1 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)160); //A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) A00(8-11) A00(8-11) A01(8-11) A01(8-11) + const __m512i lhs_mat_23_01_sp1 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)160); //A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) A02(8-11) A02(8-11) A03(8-11) A03(8-11) + const __m512i lhs_mat_01_02_sp1 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)160); //A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) A00(16-19) A00(16-19) A01(16-19) A01(16-19) + const __m512i lhs_mat_23_02_sp1 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)160); //A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) A02(16-19) A02(16-19) A03(16-19) A03(16-19) + const __m512i lhs_mat_01_03_sp1 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)160); //A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) A00(24-27) A00(24-27) A01(24-27) A01(24-27) + const __m512i lhs_mat_23_03_sp1 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)160); //A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) A02(24-27) A02(24-27) A03(24-27) A03(24-27) + + const __m512i lhs_mat_01_10_sp1 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)160); //A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) A10(0-3) A10(0-3) A11(0-3) A11(0-3) + const __m512i lhs_mat_23_10_sp1 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)160); //A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) A12(0-3) A12(0-3) A13(0-3) A13(0-3) + const __m512i lhs_mat_01_11_sp1 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)160); //A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) A10(8-11) A10(8-11) A11(8-11) A11(8-11) + const __m512i lhs_mat_23_11_sp1 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)160); //A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) A12(8-11) A12(8-11) A13(8-11) A13(8-11) + const __m512i lhs_mat_01_12_sp1 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)160); //A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) A10(16-19) A10(16-19) A11(16-19) A11(16-19) + const __m512i lhs_mat_23_12_sp1 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)160); //A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) A12(16-19) A12(16-19) A13(16-19) A13(16-19) + const __m512i lhs_mat_01_13_sp1 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)160); //A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) A10(24-27) A10(24-27) A11(24-27) A11(24-27) + const __m512i lhs_mat_23_13_sp1 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)160); //A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) A12(24-27) A12(24-27) A13(24-27) A13(24-27) + + const __m512i lhs_mat_01_00_sp2 = _mm512_shuffle_epi32(lhs_mat_01_00, (_MM_PERM_ENUM)245); //A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) A00(4-7) A00(4-7) A01(4-7) A01(4-7) + const __m512i lhs_mat_23_00_sp2 = _mm512_shuffle_epi32(lhs_mat_23_00, (_MM_PERM_ENUM)245); //A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) A02(4-7) A02(4-7) A03(4-7) A03(4-7) + const __m512i lhs_mat_01_01_sp2 = _mm512_shuffle_epi32(lhs_mat_01_01, (_MM_PERM_ENUM)245); //A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) A00(12-15) A00(12-15) A01(12-15) A01(12-15) + const __m512i lhs_mat_23_01_sp2 = _mm512_shuffle_epi32(lhs_mat_23_01, (_MM_PERM_ENUM)245); //A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) A02(12-15) A02(12-15) A03(12-15) A03(12-15) + const __m512i lhs_mat_01_02_sp2 = _mm512_shuffle_epi32(lhs_mat_01_02, (_MM_PERM_ENUM)245); //A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) A00(20-23) A00(20-23) A01(20-23) A01(20-23) + const __m512i lhs_mat_23_02_sp2 = _mm512_shuffle_epi32(lhs_mat_23_02, (_MM_PERM_ENUM)245); //A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) A02(20-23) A02(20-23) A03(20-23) A03(20-23) + const __m512i lhs_mat_01_03_sp2 = _mm512_shuffle_epi32(lhs_mat_01_03, (_MM_PERM_ENUM)245); //A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) A00(28-31) A00(28-31) A01(28-31) A01(28-31) + const __m512i lhs_mat_23_03_sp2 = _mm512_shuffle_epi32(lhs_mat_23_03, (_MM_PERM_ENUM)245); //A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) A02(28-31) A02(28-31) A03(28-31) A03(28-31) + + const __m512i lhs_mat_01_10_sp2 = _mm512_shuffle_epi32(lhs_mat_01_10, (_MM_PERM_ENUM)245); //A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) A10(4-7) A10(4-7) A11(4-7) A11(4-7) + const __m512i lhs_mat_23_10_sp2 = _mm512_shuffle_epi32(lhs_mat_23_10, (_MM_PERM_ENUM)245); //A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) A12(4-7) A12(4-7) A13(4-7) A13(4-7) + const __m512i lhs_mat_01_11_sp2 = _mm512_shuffle_epi32(lhs_mat_01_11, (_MM_PERM_ENUM)245); //A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) A10(12-15) A10(12-15) A11(12-15) A11(12-15) + const __m512i lhs_mat_23_11_sp2 = _mm512_shuffle_epi32(lhs_mat_23_11, (_MM_PERM_ENUM)245); //A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) A12(12-15) A12(12-15) A13(12-15) A13(12-15) + const __m512i lhs_mat_01_12_sp2 = _mm512_shuffle_epi32(lhs_mat_01_12, (_MM_PERM_ENUM)245); //A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) A10(20-23) A10(20-23) A11(20-23) A11(20-23) + const __m512i lhs_mat_23_12_sp2 = _mm512_shuffle_epi32(lhs_mat_23_12, (_MM_PERM_ENUM)245); //A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) A12(20-23) A12(20-23) A13(20-23) A13(20-23) + const __m512i lhs_mat_01_13_sp2 = _mm512_shuffle_epi32(lhs_mat_01_13, (_MM_PERM_ENUM)245); //A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) A10(28-31) A10(28-31) A11(28-31) A11(28-31) + const __m512i lhs_mat_23_13_sp2 = _mm512_shuffle_epi32(lhs_mat_23_13, (_MM_PERM_ENUM)245); //A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) A12(28-31) A12(28-31) A13(28-31) A13(28-31) + + // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane + __m512i iacc_mat_00_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_01_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_01_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_01_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_01_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_01_00_sp1)); + __m512i iacc_mat_10_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_11_0_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp1, lhs_mat_23_03_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp1, lhs_mat_23_02_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp1, lhs_mat_23_01_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp1, lhs_mat_23_00_sp1)); + __m512i iacc_mat_00_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_01_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_01_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_01_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_01_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_01_10_sp1)); + __m512i iacc_mat_10_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp1, lhs_mat_23_10_sp1)); + __m512i iacc_mat_11_1_sp1 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp1, lhs_mat_23_13_sp1), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp1, lhs_mat_23_12_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp1, lhs_mat_23_11_sp1)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp1, lhs_mat_23_10_sp1)); + + __m512i iacc_mat_00_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_01_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_01_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_01_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_01_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_01_00_sp2)); + __m512i iacc_mat_10_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_11_0_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_03_sp2, lhs_mat_23_03_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_02_sp2, lhs_mat_23_02_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_01_sp2, lhs_mat_23_01_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_00_sp2, lhs_mat_23_00_sp2)); + __m512i iacc_mat_00_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_01_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_01_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_01_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_01_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_01_10_sp2)); + __m512i iacc_mat_10_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_014589CD_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_014589CD_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_014589CD_10_sp2, lhs_mat_23_10_sp2)); + __m512i iacc_mat_11_1_sp2 = _mm512_add_epi16(_mm512_add_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(rhs_mat_2367ABEF_13_sp2, lhs_mat_23_13_sp2), _mm512_maddubs_epi16(rhs_mat_2367ABEF_12_sp2, lhs_mat_23_12_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_11_sp2, lhs_mat_23_11_sp2)), _mm512_maddubs_epi16(rhs_mat_2367ABEF_10_sp2, lhs_mat_23_10_sp2)); + + // Output of both shuffle patterns are added in order to sum dot product outputs of all 32 values in block + __m512i iacc_mat_00_0 = _mm512_add_epi16(iacc_mat_00_0_sp1, iacc_mat_00_0_sp2); + __m512i iacc_mat_01_0 = _mm512_add_epi16(iacc_mat_01_0_sp1, iacc_mat_01_0_sp2); + __m512i iacc_mat_10_0 = _mm512_add_epi16(iacc_mat_10_0_sp1, iacc_mat_10_0_sp2); + __m512i iacc_mat_11_0 = _mm512_add_epi16(iacc_mat_11_0_sp1, iacc_mat_11_0_sp2); + + __m512i iacc_mat_00_1 = _mm512_add_epi16(iacc_mat_00_1_sp1, iacc_mat_00_1_sp2); + __m512i iacc_mat_01_1 = _mm512_add_epi16(iacc_mat_01_1_sp1, iacc_mat_01_1_sp2); + __m512i iacc_mat_10_1 = _mm512_add_epi16(iacc_mat_10_1_sp1, iacc_mat_10_1_sp2); + __m512i iacc_mat_11_1 = _mm512_add_epi16(iacc_mat_11_1_sp1, iacc_mat_11_1_sp2); + + iacc_mat_00_0 = _mm512_madd_epi16(iacc_mat_00_0, scale_014589CD_0); + iacc_mat_01_0 = _mm512_madd_epi16(iacc_mat_01_0, scale_2367ABEF_0); + iacc_mat_10_0 = _mm512_madd_epi16(iacc_mat_10_0, scale_014589CD_0); + iacc_mat_11_0 = _mm512_madd_epi16(iacc_mat_11_0, scale_2367ABEF_0); + + iacc_mat_00_1 = _mm512_madd_epi16(iacc_mat_00_1, scale_014589CD_1); + iacc_mat_01_1 = _mm512_madd_epi16(iacc_mat_01_1, scale_2367ABEF_1); + iacc_mat_10_1 = _mm512_madd_epi16(iacc_mat_10_1, scale_014589CD_1); + iacc_mat_11_1 = _mm512_madd_epi16(iacc_mat_11_1, scale_2367ABEF_1); + + // Straighten out to make 4 row vectors (4 for each sub block which are accumulated together in the next step) + __m512i iacc_row_0_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_0, _mm512_shuffle_epi32(iacc_mat_01_0, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_0, (_MM_PERM_ENUM)78), iacc_mat_01_0); + __m512i iacc_row_2_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_0, _mm512_shuffle_epi32(iacc_mat_11_0, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3_0 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10_0, (_MM_PERM_ENUM)78), iacc_mat_11_0); + __m512i iacc_row_0_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00_1, _mm512_shuffle_epi32(iacc_mat_01_1, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00_1, (_MM_PERM_ENUM)78), iacc_mat_01_1); + __m512i iacc_row_2_1 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10_1, _mm512_shuffle_epi32(iacc_mat_11_1, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3_1 = _mm512_mask_blend_epi32(0xCCCC,_mm512_shuffle_epi32(iacc_mat_10_1, (_MM_PERM_ENUM)78), iacc_mat_11_1); + + __m512i iacc_row_0 = _mm512_add_epi32(iacc_row_0_0, iacc_row_0_1); + __m512i iacc_row_1 = _mm512_add_epi32(iacc_row_1_0, iacc_row_1_1); + __m512i iacc_row_2 = _mm512_add_epi32(iacc_row_2_0, iacc_row_2_1); + __m512i iacc_row_3 = _mm512_add_epi32(iacc_row_3_0, iacc_row_3_1); + + // Load the scale(d) values for all the 4 Q8_k blocks and repeat it across lanes + const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d); + const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); + const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); + + // Multiply with appropiate scales and accumulate (for both d and dmin) below + acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); + acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); + acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); + acc_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_3), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_rows[3]); + + __m512i iacc_row_min_0 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)0), mins_01); + __m512i iacc_row_min_1 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)85), mins_01); + __m512i iacc_row_min_2 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)170), mins_01); + __m512i iacc_row_min_3 = _mm512_madd_epi16(_mm512_shuffle_epi32(lhs_bsums_hsum_0123_01, (_MM_PERM_ENUM)255), mins_01); + + acc_min_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_0), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_min_rows[0]); + acc_min_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_1), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_min_rows[1]); + acc_min_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_2), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_min_rows[2]); + acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]); + } + } + // Store accumlated values + for (int i = 0; i < 4; i++) { + _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); + } + } + } + if (anc != nc) { + xstart = anc/8; + y = 0; + } +#endif //AVX512F + // Take group of four block_q8_Kx4 structures at each pass of the loop and perform dot product operation for (; y < anr / 4; y += 4) { @@ -4099,7 +2554,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c } // Take group of eight block_q4_kx8 structures at each pass of the loop and perform dot product operation - for (int64_t x = 0; x < nc / 8; x++) { + for (int64_t x = xstart; x < nc / 8; x++) { const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb); @@ -4433,7 +2888,7 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c const block_q8_Kx4 * a_ptr = a_ptr_start + (y * nb); - for (int64_t x = 0; x < nc / 8; x++) { + for (int64_t x = xstart; x < nc / 8; x++) { const block_q4_Kx8 * b_ptr = b_ptr_start + (x * b_nb); @@ -4827,899 +3282,3 @@ static void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, c } #endif } - -static void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - - float32x4_t sumf[4]; - for (int m = 0; m < 4; m++) { - sumf[m] = vdupq_n_f32(0); - } - - for (int l = 0; l < nb; l++) { - float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); - float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); - - int32x4_t sumi_0 = vdupq_n_s32(0); - int32x4_t sumi_1 = vdupq_n_s32(0); - int32x4_t sumi_2 = vdupq_n_s32(0); - int32x4_t sumi_3 = vdupq_n_s32(0); - - for (int k = 0; k < 4; k++) { - int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); - int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); - - uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); - int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); - int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); - - sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); - sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); - sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); - sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); - sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); - sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); - sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); - sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); - } - - sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); - sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); - sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); - sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); - } - - for (int m = 0; m < 4; m++) { - vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); - } - } - } - return; - } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - { - float sumf[4][4]; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); - } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } -} - -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x4 out; - - for (int i = 0; i < 4; i++) { - out.d[i] = in[i].d; - } - - const int end = QK4_0 * 2 / blck_size_interleave; - - if (blck_size_interleave == 8) { - const uint64_t xor_mask = 0x8888888888888888ULL; - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint64_t elems; - // Using memcpy to avoid unaligned memory accesses - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - elems ^= xor_mask; - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); - } - } else if (blck_size_interleave == 4) { - const uint32_t xor_mask = 0x88888888; - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint32_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); - elems ^= xor_mask; - memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t)); - } - } else { - GGML_ASSERT(false); - } - - return out; -} - -// interleave 8 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x8 -// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks -// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x8 out; - - for (int i = 0; i < 8; i++) { - out.d[i] = in[i].d; - } - - const int end = QK4_0 * 4 / blck_size_interleave; - const uint64_t xor_mask = 0x8888888888888888ULL; - - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - elems ^= xor_mask; - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); - } - - return out; -} - -static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { - block_q4_Kx8 out; - //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure - for (int i = 0; i < 8; i++) { - out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; - } - - for (int i = 0; i < 8; i++) { - out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; - } - - const int end = QK_K * 4 / blck_size_interleave; - - // Interleave Q4_K quants by taking 8 bytes at a time - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); - } - - // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K - // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) - // The output Q4_Kx8 structure has 96 bytes - // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure - // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures - uint8_t s[8], m[8]; - - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = in[j].scales[i] & 63; - m[j] = in[j].scales[i + 4] & 63; - } - - out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); - - } - - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 8; j++) { - s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); - } - - out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); - out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); - out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); - out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); - out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); - out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); - out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); - out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); - out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); - out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); - out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); - out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); - - } - - return out; -} - -static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 4 || interleave_block == 8); - constexpr int nrows_interleaved = 4; - - block_q4_0x4 * dst = (block_q4_0x4 *)t->data; - const block_q4_0 * src = (const block_q4_0 *)data; - block_q4_0 dst_tmp[4]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_0x4(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} -static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8); - constexpr int nrows_interleaved = 8; - - block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; - const block_q4_K * src = (const block_q4_K*) data; - block_q4_K dst_tmp[8]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - -static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 8); - constexpr int nrows_interleaved = 8; - - block_q4_0x8 * dst = (block_q4_0x8*)t->data; - const block_q4_0 * src = (const block_q4_0*) data; - block_q4_0 dst_tmp[8]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_q4_0x8(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - -static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { - block_iq4_nlx4 out; - - for (int i = 0; i < 4; i++) { - out.d[i] = in[i].d; - } - - const int end = QK4_NL * 2 / blck_size_interleave; - - // TODO: this branch seems wrong - //if (blck_size_interleave == 8) { - // for (int i = 0; i < end; ++i) { - // int src_id = i % 4; - // int src_offset = (i / 4) * blck_size_interleave; - // int dst_offset = i * blck_size_interleave; - - // // Using memcpy to avoid unaligned memory accesses - // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); - // } - //} else - if (blck_size_interleave == 4) { - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); - } - } else { - GGML_ASSERT(false); - } - - return out; -} - -static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); - //GGML_ASSERT(interleave_block == 4 || interleave_block == 8); - GGML_ASSERT(interleave_block == 4); - - block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data; - const block_iq4_nl * src = (const block_iq4_nl *)data; - block_iq4_nl dst_tmp[4]; - int nrow = ggml_nrows(t); - int nrows_interleaved = 4; - int nblocks = t->ne[0] / QK4_0; - - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); - - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { - return -1; - } - - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; - } - *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; - - GGML_UNUSED(data_size); -} - -namespace ggml::cpu::aarch64 { -// repack -template -int repack(struct ggml_tensor *, const void *, size_t); - -// TODO: generalise. -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); -} - -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); -} - -// TODO: needs to be revisited -//template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { -// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); -//} - -// gemv -template -void gemv(int, float *, size_t, const void *, const void *, int, int); - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); -} - -// gemm -template -void gemm(int, float *, size_t, const void *, const void *, int, int); - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); -} - -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); -} - -class tensor_traits_base : public ggml::cpu::tensor_traits { - public: - virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; -}; - -template class tensor_traits : public tensor_traits_base { - - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - // not realy a GGML_TYPE_Q8_0 but same size. - switch (op->op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - return true; - case GGML_OP_MUL_MAT_ID: - size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. - size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } - - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - forward_mul_mat(params, op); - return true; - case GGML_OP_MUL_MAT_ID: - forward_mul_mat_id(params, op); - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } - - void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - const ggml_tensor * src1 = op->src[1]; - ggml_tensor * dst = op; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - GGML_ASSERT(ggml_n_dims(op->src[0]) == 2); - // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2); - - char * wdata = static_cast(params->wdata); - const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); - - assert(params->wsize >= nbw1 * ne11); - - const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; - - int64_t i11_processed = 0; - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - ggml_quantize_mat_t((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10); - } - - i11_processed = ne11 - ne11 % 4; - for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { - from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10); - } - - ggml_barrier(params->threadpool); - - const void * src1_wdata = params->wdata; - const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10); - int64_t src0_start = (ith * ne01) / nth; - int64_t src0_end = ((ith + 1) * ne01) / nth; - src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start; - src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end; - if (src0_start >= src0_end) { - return; - } - - // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (ne11 > 3) { - gemm(ne00, - (float *) ((char *) dst->data) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); - } - for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { - gemv(ne00, - (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); - } - } - - void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - const ggml_tensor * src1 = op->src[1]; - const ggml_tensor * ids = op->src[2]; - ggml_tensor * dst = op; - - GGML_TENSOR_BINARY_OP_LOCALS - - const int ith = params->ith; - const int nth = params->nth; - - const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; - - // we don't support permuted src0 or src1 - GGML_ASSERT(nb00 == ggml_type_size(src0->type)); - GGML_ASSERT(nb10 == ggml_type_size(src1->type)); - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - GGML_ASSERT(ne03 == 1); - GGML_ASSERT(ne13 == 1); - GGML_ASSERT(ne3 == 1); - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // row groups - const int n_ids = ids->ne[0]; // n_expert_used - const int n_as = ne02; // n_expert - - const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); - const size_t nbw2 = nbw1*ne11; - const size_t nbw3 = nbw2*ne12; - - struct mmid_row_mapping { - int32_t i1; - int32_t i2; - }; - - GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) + - n_as * ne12 * sizeof(mmid_row_mapping))); - - auto * wdata = (char *) params->wdata; - auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); - auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] - - struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] - - // src1: float32 => param type - for (int64_t i12 = 0; i12 < ne12; ++i12) { - for (int64_t i11 = ith; i11 < ne11; i11 += nth) { - from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11), - (void *) (wdata + i12 * nbw2 + i11 * nbw1), - ne10); - } - } - -#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)] - - if (ith == 0) { - // initialize matrix_row_counts - memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); - - // group rows by src0 matrix - for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { - for (int32_t id = 0; id < n_ids; ++id) { - const int32_t i02 = - *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); - - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; - matrix_row_counts[i02] += 1; - } - } - } - - ggml_barrier(params->threadpool); - - // compute each matrix multiplication in sequence - for (int cur_a = 0; cur_a < n_as; ++cur_a) { - const int64_t cne1 = matrix_row_counts[cur_a]; - - if (cne1 == 0) { - continue; - } - - const auto * src0_cur = (const char *) src0->data + cur_a*nb02; - - //const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1; // src1 rows - - int64_t src0_cur_start = (ith * ne01) / nth; - int64_t src0_cur_end = ((ith + 1) * ne01) / nth; - - src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; - src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; - - if (src0_cur_start >= src0_cur_end) { - return; - } - - for (int ir1 = 0; ir1 < nr1; ir1++) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); - - const int id = row_mapping.i1; // selected expert index - - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 - - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row - - const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); - - gemv(ne00, - (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - src0_cur + src0_cur_start * nb01, - src1_col, 1, src0_cur_end - src0_cur_start); - } - } -#undef MMID_MATRIX_ROW - } - - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { - GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), - (int) NB_COLS, (int) INTER_SIZE); - return ggml::cpu::aarch64::repack(t, data, data_size); - } -}; - -// instance for Q4 -static const tensor_traits q4_0_4x4_q8_0; -static const tensor_traits q4_0_4x8_q8_0; -static const tensor_traits q4_0_8x8_q8_0; -static const tensor_traits q4_K_8x8_q8_K; - -// instance for IQ4 -static const tensor_traits iq4_nl_4x4_q8_0; - -} // namespace ggml::cpu::aarch64 - -static const ggml::cpu::tensor_traits * ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { - if (cur->ne[1] % 8 == 0) { - return &ggml::cpu::aarch64::q4_0_8x8_q8_0; - } - } - if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { - if (cur->ne[1] % 4 == 0) { - return &ggml::cpu::aarch64::q4_0_4x8_q8_0; - } - } - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - if (cur->ne[1] % 4 == 0) { - return &ggml::cpu::aarch64::q4_0_4x4_q8_0; - } - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (ggml_cpu_has_avx2()) { - if (cur->ne[1] % 8 == 0) { - return &ggml::cpu::aarch64::q4_K_8x8_q8_K; - } - } - } else if (cur->type == GGML_TYPE_IQ4_NL) { - if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { - if (cur->ne[1] % 4 == 0) { - return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0; - } - } - } - - return nullptr; -} - -static enum ggml_status ggml_backend_cpu_aarch64_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { - tensor->extra = (void *) const_cast(ggml_aarch64_get_optimal_repack_type(tensor)); - - GGML_UNUSED(buffer); - return GGML_STATUS_SUCCESS; -} - -static void ggml_backend_cpu_aarch64_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, - const void * data, size_t offset, size_t size) { - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); - - auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra; - auto OK = tensor_traits->repack(tensor, data, size); - - GGML_ASSERT(OK == 0); - GGML_UNUSED(buffer); -} - -static const char * ggml_backend_cpu_aarch64_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "CPU_AARCH64"; - - GGML_UNUSED(buft); -} - -static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { - return nullptr; - } - - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_cpu_aarch64_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; -} - -static size_t ggml_backend_cpu_aarch64_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return TENSOR_ALIGNMENT; - - GGML_UNUSED(buft); -} - -namespace ggml::cpu::aarch64 { -class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - if ( op->op == GGML_OP_MUL_MAT && - op->src[0]->buffer && - (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() && - ggml_aarch64_get_optimal_repack_type(op->src[0]) - ) { - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { - return false; - } - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } - //if (op->src[1]->type == GGML_TYPE_Q8_0) { - // return true; - //} - // may be possible if Q8_0 packed... - } else if (op->op == GGML_OP_MUL_MAT_ID - && op->src[0]->buffer - && (ggml_n_dims(op->src[0]) == 3) - && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type() - && ggml_aarch64_get_optimal_repack_type(op->src[0]) - ) { - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { - return false; - } - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } - //if (op->src[1]->type == GGML_TYPE_Q8_0) { - // return true; - //} - } - return false; - } - - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { - if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) { - if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_aarch64_buffer_type()) { - return (ggml::cpu::tensor_traits *) op->src[0]->extra; - } - } - return nullptr; - } -}; -} // namespace ggml::cpu::aarch64 - -ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_aarch64 = { - /* .iface = */ { - /* .get_name = */ ggml_backend_cpu_aarch64_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_cpu_aarch64_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_cpu_aarch64_buffer_type_get_alignment, - /* .get_max_size = */ nullptr, // defaults to SIZE_MAX - /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes - /* .is_host = */ nullptr, - }, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), - /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(), - }; - - return &ggml_backend_cpu_buffer_type_aarch64; -} diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 3df01c1ed..5624176cc 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -1,7 +1,7 @@ #pragma once #include "ggml.h" -#include "ggml-cpu-traits.h" +#include "traits.h" #include "ggml-cpu-impl.h" #include "ggml-impl.h" diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.h b/ggml/src/ggml-cpu/ggml-cpu-aarch64.h deleted file mode 100644 index 6e84c826b..000000000 --- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +++ /dev/null @@ -1,8 +0,0 @@ -#pragma once - -#include "ggml-cpu-traits.h" -#include "ggml.h" - -// GGML internal header - -ggml_backend_buffer_type_t ggml_backend_cpu_aarch64_buffer_type(void); diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index e4af07635..73a8f9398 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -320,21 +320,17 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #ifdef __wasm_simd128__ #include -#else +#endif + #ifdef __POWER9_VECTOR__ #include -#else +#endif + #if defined(_MSC_VER) || defined(__MINGW32__) #include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) -#if !defined(__riscv) +#elif defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__) #include #endif -#endif -#endif -#endif -#endif #ifdef __riscv_v_intrinsic #include @@ -375,7 +371,7 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #define vec_xor(a, b) ((a) ^ (b)) // Vector XOR #endif -typedef signed char char8x16_t __attribute__((vector_size(16))); +typedef signed char char8x16_t __attribute__((vector_size(16))); typedef unsigned char uchar8x16_t __attribute__((vector_size(16))); typedef int8_t int8x16_t __attribute__((vector_size(16))); @@ -386,10 +382,10 @@ typedef uint8_t uint8x16_t __attribute__((vector_size(16))); typedef uint16_t uint16x8_t __attribute__((vector_size(16))); typedef uint32_t uint32x4_t __attribute__((vector_size(16))); -typedef float float32x4_t __attribute__((vector_size(16))); -typedef double double64x2_t __attribute((vector_size(16))); +typedef float float32x4_t __attribute__((vector_size(16))); +typedef double double64x2_t __attribute__((vector_size(16))); -typedef signed long long long64x2_t __attribute((vector_size(16))); +typedef signed long long long64x2_t __attribute__((vector_size(16))); typedef unsigned long long ulong64x2_t __attribute__((vector_size(16))); typedef struct ggml_uint8x16x2_t { @@ -507,6 +503,9 @@ static __m256 __lasx_xvreplfr2vr_s(const float val) { // TODO: move to ggml-threading void ggml_barrier(struct ggml_threadpool * tp); +void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value); +int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ggml/src/ggml-cpu/ggml-cpu-quants.c deleted file mode 100644 index 91a81bdc3..000000000 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.c +++ /dev/null @@ -1,13026 +0,0 @@ -#define GGML_COMMON_IMPL_C -#include "ggml-common.h" - -#include "ggml-quants.h" -#include "ggml-cpu-quants.h" -#include "ggml-impl.h" -#include "ggml-cpu-impl.h" -#include "ggml-cpu.h" - -#include -#include -#include -#include -#include // for qsort -#include // for GGML_ASSERT - -#define GROUP_MAX_EPS 1e-15f -#define GROUP_MAX_EPS_IQ3_XXS 1e-8f -#define GROUP_MAX_EPS_IQ2_S 1e-8f -#define GROUP_MAX_EPS_IQ1_M 1e-7f -#define GROUP_MAX_EPS_IQ1_S 1e-12f - -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid warnings for hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) -#endif - -#define UNUSED GGML_UNUSED - -// some compilers don't provide _mm256_set_m128i, e.g. gcc 7 -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(x, x); - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(y, x); - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); - const __m128i ones = _mm_set1_epi16(1); - return _mm_madd_epi16(ones, dot); -} - -#if __AVX__ || __AVX2__ || __AVX512F__ -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - const __m128i hi64 = _mm_unpackhi_epi64(a, a); - const __m128i sum64 = _mm_add_epi32(hi64, a); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -#if defined(__AVX2__) || defined(__AVX512F__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = _mm256_set_epi64x( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); - const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytes = _mm256_or_si256(bytes, bit_mask); - return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - return _mm256_and_si256(lowMask, bytes); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - const __m256i summed_pairs = _mm256_madd_epi16(ones, x); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); - return _mm256_cvtepi32_ps(summed_pairs); -#elif defined(__AVXVNNI__) - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_float(dot); -#endif -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_float(ax, sy); -#endif -} - -static inline __m128i packNibbles( __m256i bytes ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh -#if __AVX512F__ - const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 - bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh - return _mm256_cvtepi16_epi8(bytes); // abcd_efgh -#else - const __m256i lowByte = _mm256_set1_epi16( 0xFF ); - __m256i high = _mm256_andnot_si256( lowByte, bytes ); - __m256i low = _mm256_and_si256( lowByte, bytes ); - high = _mm256_srli_epi16( high, 4 ); - bytes = _mm256_or_si256( low, high ); - - // Compress uint16_t lanes into bytes - __m128i r0 = _mm256_castsi256_si128( bytes ); - __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); - return _mm_packus_epi16( r0, r1 ); -#endif -} -#elif defined(__AVX__) -static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m128i lowByte = _mm_set1_epi16( 0xFF ); - __m128i high = _mm_andnot_si128( lowByte, bytes1 ); - __m128i low = _mm_and_si128( lowByte, bytes1 ); - high = _mm_srli_epi16( high, 4 ); - bytes1 = _mm_or_si128( low, high ); - high = _mm_andnot_si128( lowByte, bytes2 ); - low = _mm_and_si128( lowByte, bytes2 ); - high = _mm_srli_epi16( high, 4 ); - bytes2 = _mm_or_si128( low, high ); - - return _mm_packus_epi16( bytes1, bytes2); -} - -static inline __m128i mul_add_epi8_sse(const __m128i x, const __m128i y) { - const __m128i ax = _mm_sign_epi8(x, x); - const __m128i sy = _mm_sign_epi8(y, x); - return _mm_maddubs_epi16(ax, sy); -} - -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); - __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); - __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); - const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytesl = _mm_or_si128(bytesl, bit_mask); - bytesh = _mm_or_si128(bytesh, bit_mask); - bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); - bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return MM256_SET_M128I(bytesh, bytesl); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - // Load 16 bytes from memory - __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); - __m128i tmph = _mm_srli_epi16(tmpl, 4); - const __m128i lowMask = _mm_set1_epi8(0xF); - tmpl = _mm_and_si128(lowMask, tmpl); - tmph = _mm_and_si128(lowMask, tmph); - return MM256_SET_M128I(tmph, tmpl); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { - const __m128i ones = _mm_set1_epi16(1); - const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); - const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - const __m128i axl = _mm256_castsi256_si128(ax); - const __m128i axh = _mm256_extractf128_si256(ax, 1); - const __m128i syl = _mm256_castsi256_si128(sy); - const __m128i syh = _mm256_extractf128_si256(sy, 1); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - const __m128i xl = _mm256_castsi256_si128(x); - const __m128i xh = _mm256_extractf128_si256(x, 1); - const __m128i yl = _mm256_castsi256_si128(y); - const __m128i yh = _mm256_extractf128_si256(y, 1); - // Get absolute values of x vectors - const __m128i axl = _mm_sign_epi8(xl, xl); - const __m128i axh = _mm_sign_epi8(xh, xh); - // Sign the values of the y vectors - const __m128i syl = _mm_sign_epi8(yl, xl); - const __m128i syh = _mm_sign_epi8(yh, xh); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -// larger version of mul_sum_i8_pairs_float where x and y are each represented by four 128-bit vectors -static inline __m256 mul_sum_i8_quad_float(const __m128i x_1_0, const __m128i x_1_1, const __m128i x_2_0, const __m128i x_2_1, - const __m128i y_1_0, const __m128i y_1_1, const __m128i y_2_0, const __m128i y_2_1) { - const __m128i mone = _mm_set1_epi16(1); - - const __m128i p16_1_0 = mul_add_epi8_sse(x_1_0, y_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(x_1_1, y_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(x_2_0, y_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(x_2_1, y_2_1); - const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, mone); - const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, mone); - const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, mone); - const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, mone); - const __m128i p_1 = _mm_add_epi32(p_1_0, p_1_1); - const __m128i p_2 = _mm_add_epi32(p_2_0, p_2_1); - return _mm256_cvtepi32_ps(MM256_SET_M128I(p_2, p_1)); -} - -// quad fp16 delta calculation -static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const float x1, const float y1) { - // GGML_FP16_TO_FP32 is faster than Intel F16C - return _mm256_set_m128(_mm_set1_ps(GGML_FP16_TO_FP32(x1) * GGML_FP16_TO_FP32(y1)), - _mm_set1_ps(GGML_FP16_TO_FP32(x0) * GGML_FP16_TO_FP32(y0))); -} -#endif -#elif defined(__SSSE3__) -// horizontally add 4x4 floats -static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { - __m128 res_0 =_mm_hadd_ps(a, b); - __m128 res_1 =_mm_hadd_ps(c, d); - __m128 res =_mm_hadd_ps(res_0, res_1); - res =_mm_hadd_ps(res, res); - res =_mm_hadd_ps(res, res); - - return _mm_cvtss_f32(res); -} -#endif // __AVX__ || __AVX2__ || __AVX512F__ -#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) - -#if defined(__ARM_NEON) || defined(__wasm_simd128__) || defined(__POWER9_VECTOR__) -#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s -#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) -#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) -#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) -#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) -#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) -#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) -#define B8(c,s ) B7(c,s, c), B7(c,s, s) - -// precomputed tables for expanding 8bits to 8 bytes: -static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 -static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 -#endif - -#if defined(__loongarch_sx) - -static __m128i lsx_packs_w(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_w(a, 15); - tmp1 = __lsx_vsat_w(b, 15); - return __lsx_vpickev_h(tmp1, tmp); -} - -static __m128i lsx_packs_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_h(a, 7); - tmp1 = __lsx_vsat_h(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - -static __m128i lsx_packus_h(__m128i a, __m128i b) { - __m128i tmp, tmp1; - tmp = __lsx_vsat_hu(a, 7); - tmp1 = __lsx_vsat_hu(b, 7); - return __lsx_vpickev_b(tmp1, tmp); -} - -static __m128i lsx_maddubs_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_h_b(a, b); - tmp2 = __lsx_vmulwod_h_b(a, b); - return __lsx_vsadd_h(tmp1, tmp2); -} - -static __m128i lsx_madd_h(__m128i a, __m128i b) { - __m128i tmp1, tmp2; - tmp1 = __lsx_vmulwev_w_h(a, b); - tmp2 = __lsx_vmulwod_w_h(a, b); - return __lsx_vadd_w(tmp1, tmp2); -} - -static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) { - v4i32 __ret = {d, c, b, a}; - return (__m128i)__ret; -} - -static __m128i lsx_shuffle_b(__m128i a, __m128i b) { - __m128i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lsx_vreplgr2vr_b(f); - zero = __lsx_vldi(0); - tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones - return __lsx_vshuf_b(a, zero, tmp2); -} - -static __m128i lsx_hadd_h(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_h(b, a); - __m128i tmp2 = __lsx_vpickod_h(b, a); - return __lsx_vadd_h(tmp1, tmp2); -} - -static __m128i lsx_hadd_w(__m128i a, __m128i b) { - __m128i tmp1 = __lsx_vpickev_w(b, a); - __m128i tmp2 = __lsx_vpickod_w(b, a); - return __lsx_vadd_w(tmp1, tmp2); -} - -static __m128 lsx_hadd_s(__m128 a, __m128 b) { - __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a); - __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a); - - return __lsx_vfadd_s(tmp1, tmp2); -} - -static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { - __m128 res_0 =lsx_hadd_s(a, b); - __m128 res_1 =lsx_hadd_s(c, d); - __m128 res =lsx_hadd_s(res_0, res_1); - res =lsx_hadd_s(res, res); - res =lsx_hadd_s(res, res); - - return ((v4f32)res)[0]; -} -#endif - -#if defined(__loongarch_asx) - -#ifdef __clang__ -#define VREGS_PREFIX "$vr" -#define XREGS_PREFIX "$xr" -#else // GCC -#define VREGS_PREFIX "$f" -#define XREGS_PREFIX "$f" -#endif -#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31" -// Convert __m128i to __m256i -static inline __m256i ____m256i(__m128i in) { - __m256i out = __lasx_xvldi(0); - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX"\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "+f" (out) : [in] "f" (in) - ); - return out; -} -// Convert two __m128i to __m256i -static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) { - __m256i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[lo], " VREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".ifnc %[out], %[hi] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " XREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[hi], " VREGS_PREFIX "\\j \n\t" - " xvori.b $xr\\i, $xr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out), [hi] "+f" (inhi) - : [lo] "f" (inlo) - ); - return out; -} -// Convert __m256i low part to __m128i -static inline __m128i lasx_extracti128_lo(__m256i in) { - __m128i out; - __asm__ volatile ( - ".ifnc %[out], %[in] \n\t" - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " vori.b $vr\\i, $vr\\j, 0 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - ".endif \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} -// Convert __m256i high part to __m128i -static inline __m128i lasx_extracti128_hi(__m256i in) { - __m128i out; - __asm__ volatile ( - ".irp i," __ALL_REGS "\n\t" - " .ifc %[out], " VREGS_PREFIX "\\i \n\t" - " .irp j," __ALL_REGS "\n\t" - " .ifc %[in], " XREGS_PREFIX "\\j \n\t" - " xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t" - " .endif \n\t" - " .endr \n\t" - " .endif \n\t" - ".endr \n\t" - : [out] "=f" (out) : [in] "f" (in) - ); - return out; -} - -static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1, int e0) { - v8i32 __ret = {e0, e1, e2, e3, e4, e5, e6, e7}; - return (__m256i)__ret; -} - -static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) { - v4i64 __ret = {d, c, b, a}; - return (__m256i)__ret; -} - -static __m256i lasx_insertf128( __m128i x, __m128i y) { - return lasx_set_q(x, y); -} - -static __m256i lasx_shuffle_b(__m256i a, __m256i b) { - __m256i mask_f, zero, tmp0, tmp2, mask; - int f = 0x8f; - mask_f = __lasx_xvreplgr2vr_b(f); - zero = __lasx_xvldi(0); - tmp0 = __lasx_xvand_v(b, mask_f); // get mask with low 4 bit and sign bits - tmp0 = __lasx_xvori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive - mask = __lasx_xvsle_b(zero, tmp0); // if mask >= 0, set mask - tmp2 = __lasx_xvand_v(tmp0, mask); // maskout the in2 < ones - return __lasx_xvshuf_b(a, zero, tmp2); -} - -static __m256i lasx_extu8_16(__m128i a) { - return __lasx_vext2xv_hu_bu(____m256i(a)); -} - -static __m256i lasx_ext8_16(__m128i a) { - return __lasx_vext2xv_h_b(____m256i(a)); -} - -static __m256i lasx_ext16_32(__m128i a) { - return __lasx_vext2xv_w_h(____m256i(a)); -} - -static __m128i lasx_extracti128( __m256i a, int pos) { - __m128i ret; - if( pos == 0) - { - ret = lasx_extracti128_lo(a); - } else { - ret = lasx_extracti128_hi(a); - } - return ret; -} - -static __m128 lasx_extractf128( __m256 a, int pos) { - __m128 ret; - if( pos == 0) - { - ret = (__m128)lasx_extracti128_lo((__m256i)a); - } else { - ret = (__m128)lasx_extracti128_hi((__m256i)a); - } - return ret; -} - -static __m256i lasx_maddubs_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_h_b(a, b); - tmp2 = __lasx_xvmulwod_h_b(a, b); - return __lasx_xvsadd_h(tmp1, tmp2); -} - -static __m256i lasx_madd_h(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_w_h(a, b); - tmp2 = __lasx_xvmulwod_w_h(a, b); - return __lasx_xvadd_w(tmp1, tmp2); -} - -static __m256i lasx_packs_w(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_w(a, 15); - tmp1 = __lasx_xvsat_w(b, 15); - return __lasx_xvpickev_h(tmp1, tmp); -} - -static __m256i lasx_packs_h(__m256i a, __m256i b) { - __m256i tmp, tmp1; - tmp = __lasx_xvsat_h(a, 7); - tmp1 = __lasx_xvsat_h(b, 7); - return __lasx_xvpickev_b(tmp1, tmp); -} - -static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) { - __m256i tmp1, tmp2; - tmp1 = __lasx_xvmulwev_h_b(a, b); - tmp2 = __lasx_xvmulwod_h_b(a, b); - return __lasx_xvadd_h(tmp1, tmp2); -} - -static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) { - switch (b) { - case 0: return __lasx_xvrepl128vei_h(a, 0); - case 1: return __lasx_xvrepl128vei_h(a, 1); - case 2: return __lasx_xvrepl128vei_h(a, 2); - case 3: return __lasx_xvrepl128vei_h(a, 3); - case 4: return __lasx_xvrepl128vei_h(a, 4); - case 5: return __lasx_xvrepl128vei_h(a, 5); - case 6: return __lasx_xvrepl128vei_h(a, 6); - case 7: return __lasx_xvrepl128vei_h(a, 7); - default: __builtin_unreachable(); - } -} - -static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) { - switch (b) { - case 0: return __lasx_xvandi_b(a, 1 << 0); - case 1: return __lasx_xvandi_b(a, 1 << 1); - case 2: return __lasx_xvandi_b(a, 1 << 2); - case 3: return __lasx_xvandi_b(a, 1 << 3); - case 4: return __lasx_xvandi_b(a, 1 << 4); - case 5: return __lasx_xvandi_b(a, 1 << 5); - case 6: return __lasx_xvandi_b(a, 1 << 6); - case 7: return __lasx_xvandi_b(a, 1 << 7); - default: __builtin_unreachable(); - } -} - -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = __lsx_vsigncov_b(x, x); - // Sign the values of the y vectors - const __m128i sy = __lsx_vsigncov_b(x, y); - // Perform multiplication and create 16-bit values - const __m128i dot = lsx_maddubs_h(ax, sy); - const __m128i ones = __lsx_vreplgr2vr_h(1); - return lsx_madd_h(ones, dot); -} - -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = lasx_extractf128(x, 1); - res = __lsx_vfadd_s(res, lasx_extractf128(x, 0)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res)); - res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0)); - return ((v4f32)res)[0]; -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - - __m256i tmp1 = __lasx_xvpermi_q(a, a, 0x11); - __m256i tmp2 = __lasx_xvpermi_q(a, a, 0x00); - - __m128i tmp1_128 = lasx_extracti128_lo(tmp1); - __m128i tmp2_128 = lasx_extracti128_lo(tmp2); - - __m128i sum128 = __lsx_vadd_w(tmp1_128, tmp2_128); - - __m128i ev = __lsx_vpickev_w(sum128, sum128); - __m128i od = __lsx_vpickod_w(sum128, sum128); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - __m128i ev = __lsx_vpickev_w(a, a); - __m128i od = __lsx_vpickod_w(a, a); - __m128i sum64 = __lsx_vadd_w(ev, od); - - int sum64_1, sum64_2; - sum64_1 = __lsx_vpickve2gr_w(sum64, 0); - sum64_2 = __lsx_vpickve2gr_w(sum64, 1); - - return sum64_1 + sum64_2; -} - -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = lasx_set_d( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - - __m256i bytes = lasx_shuffle_b(__lasx_xvreplgr2vr_w(x32), shuf_mask); - const __m256i bit_mask = __lasx_xvreplgr2vr_d(0x7fbfdfeff7fbfdfe); - bytes = __lasx_xvor_v(bytes, bit_mask); - return __lasx_xvseq_b(bytes, __lasx_xvreplgr2vr_d(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { - const __m128i lo = __lsx_vld((const __m128i *)rsi, 0); - __m128i hi = __lsx_vsrli_h(lo, 4); - return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - __m256i v = __lasx_xvpackod_h(x, x); - __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v); - return __lasx_xvffint_s_w(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - // Perform multiplication and create 16-bit values - const __m256i dot = lasx_maddubs_h(ax, sy); - return sum_i16_pairs_float(dot); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - const __m256i dot = lasx_madd_h_b(x, y); - return sum_i16_pairs_float(dot); -} - -static inline __m128i packNibbles( __m256i bytes ) { - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m256i lowByte = __lasx_xvreplgr2vr_h(0xFF); - __m256i high = __lasx_xvandn_v(lowByte, bytes); - __m256i low = __lasx_xvand_v(lowByte, bytes); - high = __lasx_xvsrli_h(high, 4); - bytes = __lasx_xvor_v(low, high); - // Compress uint16_t lanes into bytes - __m128i *r0 = (__m128i *)&bytes; - __m256i tmp_h128 = __lasx_xvpermi_q(bytes, bytes, 0x11); - __m128i *r1 = (__m128i *)&tmp_h128; - - __m128i zero = __lsx_vldi(0); - __m128i tmp, tmp2, tmp3; - - tmp = __lsx_vmax_h(zero, *r0); - tmp2 = __lsx_vsat_hu(tmp, 7); - - tmp = __lsx_vmax_h(zero, *r1); - tmp3 = __lsx_vsat_hu(tmp, 7); - return __lsx_vpickev_b(tmp3, tmp2); -} -#endif //__loongarch_asx - -void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_q4_0_ref(x, y, k); -} - -void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_q4_1_ref(x, y, k); -} - -void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_q5_0_ref(x, y, k); -} - -void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - quantize_row_q5_1_ref(x, y, k); -} - -void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0 * GGML_RESTRICT y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } -#elif defined __wasm_simd128__ - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = QK8_0; - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_0, vl); - - vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); - - // convert to integer - vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); - vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); - - // store result - __riscv_vse8_v_i8m2(y[i].qs , vs, vl); - } - -#elif defined(__POWER9_VECTOR__) - for (int i = 0; i < nb; i++) { - vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; - vector signed int vi[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - const vector float vid = vec_splats(id); - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const vector float v = vec_round(vec_mul(srcv[j], vid)); - vi[j] = vec_cts(v, 0); - } - vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); - vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); - } - -#elif defined(__loongarch_asx) - for (int i = 0; i < nb; i++) { - __m256 v0 = (__m256)__lasx_xvld( x , 0); - __m256 v1 = (__m256)__lasx_xvld( x , 32); - __m256 v2 = (__m256)__lasx_xvld( x , 64); - __m256 v3 = (__m256)__lasx_xvld( x , 96); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); - __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); - - __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs , 0) ); - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); - __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 )); - const float max_scalar = ((v4f32)max4)[0]; - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = (__m256)__lasx_xvreplfr2vr_s( id ); - - // Apply the multiplier - v0 = __lasx_xvfmul_s( v0, mul ); - v1 = __lasx_xvfmul_s( v1, mul ); - v2 = __lasx_xvfmul_s( v2, mul ); - v3 = __lasx_xvfmul_s( v3, mul ); - - // Round to nearest integer - __m256i i0 = __lasx_xvftintrne_w_s( v0 ); - __m256i i1 = __lasx_xvftintrne_w_s( v1 ); - __m256i i2 = __lasx_xvftintrne_w_s( v2 ); - __m256i i3 = __lasx_xvftintrne_w_s( v3 ); - - __m128i ni0 = lasx_extracti128( i0, 0 ); - __m128i ni1 = lasx_extracti128( i0, 1); - __m128i ni2 = lasx_extracti128( i1, 0); - __m128i ni3 = lasx_extracti128( i1, 1); - __m128i ni4 = lasx_extracti128( i2, 0); - __m128i ni5 = lasx_extracti128( i2, 1); - __m128i ni6 = lasx_extracti128( i3, 0); - __m128i ni7 = lasx_extracti128( i3, 1); - - // Convert int32 to int16 - ni0 = lsx_packs_w( ni0, ni1 ); - ni2 = lsx_packs_w( ni2, ni3 ); - ni4 = lsx_packs_w( ni4, ni5 ); - ni6 = lsx_packs_w( ni6, ni7 ); - // Convert int16 to int8 - ni0 = lsx_packs_h( ni0, ni2 ); - ni4 = lsx_packs_h( ni4, ni6 ); - - __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); - __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); - - } -#elif defined(__VXE__) || defined(__VXE2__) - for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f / d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); - - y[i].qs[4*j + 0] = vec_extract(vi, 0); - y[i].qs[4*j + 1] = vec_extract(vi, 1); - y[i].qs[4*j + 2] = vec_extract(vi, 2); - y[i].qs[4*j + 3] = vec_extract(vi, 3); - } - } -#else - GGML_UNUSED(nb); - // scalar - quantize_row_q8_0_ref(x, y, k); -#endif -} - -void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - block_q8_1 * GGML_RESTRICT y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - int32x4_t accv = vdupq_n_s32(0); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - - accv = vaddq_s32(accv, vi); - } - - y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv)); - } -#elif defined __wasm_simd128__ - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - v128_t accv = wasm_i32x4_splat(0); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - - accv = wasm_i32x4_add(accv, vi); - } - - y[i].s = GGML_FP32_TO_FP16( - d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3))); - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float max_scalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)))); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); - const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(_mm_add_epi32(s0, s1))); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#elif defined(__riscv_v_intrinsic) - - size_t vl = QK8_1; - - for (int i = 0; i < nb; i++) { - // load elements - vfloat32m8_t v_x = __riscv_vle32_v_f32m8(x+i*QK8_1, vl); - - vfloat32m8_t vfabs = __riscv_vfabs_v_f32m8(v_x, vl); - vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); - vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m8_f32m1(vfabs, tmp, vl); - float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - vfloat32m8_t x0 = __riscv_vfmul_vf_f32m8(v_x, id, vl); - - // convert to integer - vint16m4_t vi = __riscv_vfncvt_x_f_w_i16m4(x0, vl); - vint8m2_t vs = __riscv_vncvt_x_x_w_i8m2(vi, vl); - - // store result - __riscv_vse8_v_i8m2(y[i].qs , vs, vl); - - // compute sum for y[i].s - vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); - vint16m1_t vwrs = __riscv_vwredsum_vs_i8m2_i16m1(vs, tmp2, vl); - - // set y[i].s - int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); - y[i].s = GGML_FP32_TO_FP16(sum*d); - } - -#elif defined(__POWER9_VECTOR__) - for (int i = 0; i < nb; i++) { - vector float srcv [8]; - vector float asrcv[8]; - vector float amaxv[8]; - vector signed int vi[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - const vector float vid = vec_splats(id); - - y[i].d = GGML_FP32_TO_FP16(d); - - vector int accv = vec_splats(0); - - for (int j = 0; j < 8; j++) { - const vector float v = vec_round(vec_mul(srcv[j], vid)); - vi[j] = vec_cts(v, 0); - - accv = vec_add(accv, vi[j]); - } - vec_xst(vec_pack(vec_pack(vi[0], vi[1]), vec_pack(vi[2], vi[3])), 0, &y[i].qs[0]); - vec_xst(vec_pack(vec_pack(vi[4], vi[5]), vec_pack(vi[6], vi[7])), 16, &y[i].qs[0]); - - accv = vec_add(accv, vec_sld(accv, accv, 4)); - accv = vec_add(accv, vec_sld(accv, accv, 8)); - y[i].s = GGML_FP32_TO_FP16(d * vec_extract(accv, 0)); - } - -#elif defined(__loongarch_asx) - for (int i = 0; i < nb; i++) { - __m256 v0 = (__m256)__lasx_xvld( x , 0 ); - __m256 v1 = (__m256)__lasx_xvld( x , 32 ); - __m256 v2 = (__m256)__lasx_xvld( x , 64 ); - __m256 v3 = (__m256)__lasx_xvld( x , 96 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 sign_bit = __lasx_xvreplfr2vr_s( -0.0f ); - __m256 max_abs = (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v0 ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v1 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v2 ) ); - max_abs = __lasx_xvfmax_s( max_abs, (__m256)__lasx_xvandn_v( (__m256i)sign_bit, (__m256i)v3 ) ); - - __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); - __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); - const float max_scalar = ((v4f32)max4)[0]; - - // Quantize these floats - const float d = max_scalar / 127.f; - y[i].d = GGML_FP32_TO_FP16(d); - const float id = ( max_scalar != 0.0f ) ? 127.f / max_scalar : 0.0f; - const __m256 mul = __lasx_xvreplfr2vr_s( id ); - - // Apply the multiplier - v0 = __lasx_xvfmul_s( v0, mul ); - v1 = __lasx_xvfmul_s( v1, mul ); - v2 = __lasx_xvfmul_s( v2, mul ); - v3 = __lasx_xvfmul_s( v3, mul ); - - // Round to nearest integer - __m256i i0 = __lasx_xvftintrne_w_s( v0 ); - __m256i i1 = __lasx_xvftintrne_w_s( v1 ); - __m256i i2 = __lasx_xvftintrne_w_s( v2 ); - __m256i i3 = __lasx_xvftintrne_w_s( v3 ); - - __m128i ni0 = lasx_extracti128(i0, 0); - __m128i ni1 = lasx_extracti128( i0, 1); - __m128i ni2 = lasx_extracti128( i1, 0); - __m128i ni3 = lasx_extracti128( i1, 1); - __m128i ni4 = lasx_extracti128( i2, 0 ); - __m128i ni5 = lasx_extracti128( i2, 1); - __m128i ni6 = lasx_extracti128( i3, 0); - __m128i ni7 = lasx_extracti128( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = __lsx_vadd_w(__lsx_vadd_w(ni0, ni1), __lsx_vadd_w(ni2, ni3)); - const __m128i s1 = __lsx_vadd_w(__lsx_vadd_w(ni4, ni5), __lsx_vadd_w(ni6, ni7)); - y[i].s = GGML_FP32_TO_FP16(d * hsum_i32_4(__lsx_vadd_w(s0, s1))); - - // Convert int32 to int16 - ni0 = lsx_packs_w( ni0, ni1 ); - ni2 = lsx_packs_w( ni2, ni3 ); - ni4 = lsx_packs_w( ni4, ni5 ); - ni6 = lsx_packs_w( ni6, ni7 ); - // Convert int16 to int8 - ni0 = lsx_packs_h( ni0, ni2 ); - ni4 = lsx_packs_h( ni4, ni6 ); - - __lsx_vst(ni0, (__m128i *)(y[i].qs + 0), 0); - __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0); - } -#elif defined(__VXE__) || defined(__VXE2__) - for (int i = 0; i < nb; i++) { - __vector float srcv [8]; - __vector float asrcv[8]; - __vector float amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); - for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(vec_extract(amaxv[0], 0), - vec_extract(amaxv[0], 1)), - MAX(vec_extract(amaxv[0], 2), - vec_extract(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f / d : 0.0f; - - y[i].d = GGML_FP32_TO_FP16(d); - - __vector int32_t acc = vec_splats(0); - - for (int j = 0; j < 8; j++) { - const __vector float v = vec_mul(srcv[j], vec_splats(id)); - const __vector int32_t vi = vec_signed(v); - - y[i].qs[4*j + 0] = vec_extract(vi, 0); - y[i].qs[4*j + 1] = vec_extract(vi, 1); - y[i].qs[4*j + 2] = vec_extract(vi, 2); - y[i].qs[4*j + 3] = vec_extract(vi, 3); - - acc = vec_add(acc, vi); - } - - y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3])); - } -#else - GGML_UNUSED(nb); - // scalar - quantize_row_q8_1_ref(x, y, k); -#endif -} - -// -// 2-6 bit quantization in super-blocks -// - -// -// ===================== Helper functions -// -static inline int nearest_int(float fval) { - assert(fabsf(fval) <= 4194303.f); - float val = fval + 12582912.f; - int i; memcpy(&i, &val, sizeof(int)); - return (i & 0x007fffff) - 0x00400000; -} - -static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, int rmse_type, - const float * GGML_RESTRICT qw) { - float max = 0; - float amax = 0; - for (int i = 0; i < n; ++i) { - float ax = fabsf(x[i]); - if (ax > amax) { amax = ax; max = x[i]; } - } - if (amax < GROUP_MAX_EPS) { // all zero - for (int i = 0; i < n; ++i) { - L[i] = 0; - } - return 0.f; - } - float iscale = -nmax / max; - if (rmse_type == 0) { - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); - } - return 1/iscale; - } - bool return_early = false; - if (rmse_type < 0) { - rmse_type = -rmse_type; - return_early = true; - } - float sumlx = 0; - float suml2 = 0; -#ifdef HAVE_BUGGY_APPLE_LINKER - // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 - for (volatile int i = 0; i < n; ++i) { -#else - for (int i = 0; i < n; ++i) { -#endif - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - L[i] = l + nmax; - float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); - sumlx += w*x[i]*l; - suml2 += w*l*l; - } - float scale = suml2 ? sumlx/suml2 : 0.0f; - if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; - float best = scale * sumlx; - for (int is = -9; is <= 9; ++is) { - if (is == 0) { - continue; - } - iscale = -(nmax + 0.1f*is) / max; - sumlx = suml2 = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i])); - sumlx += w*x[i]*l; - suml2 += w*l*l; - } - if (suml2 > 0 && sumlx*sumlx > best*suml2) { - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); - } - scale = sumlx/suml2; best = scale*sumlx; - } - } - return scale; -} - -static float make_q3_quants(int n, int nmax, const float * GGML_RESTRICT x, int8_t * GGML_RESTRICT L, bool do_rmse) { - float max = 0; - float amax = 0; - for (int i = 0; i < n; ++i) { - float ax = fabsf(x[i]); - if (ax > amax) { amax = ax; max = x[i]; } - } - if (amax < GROUP_MAX_EPS) { // all zero - for (int i = 0; i < n; ++i) { L[i] = 0; } - return 0.f; - } - float iscale = -nmax / max; - if (do_rmse) { - float sumlx = 0; - float suml2 = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - L[i] = l; - float w = x[i]*x[i]; - sumlx += w*x[i]*l; - suml2 += w*l*l; - } - for (int itry = 0; itry < 5; ++itry) { - int n_changed = 0; - for (int i = 0; i < n; ++i) { - float w = x[i]*x[i]; - float slx = sumlx - w*x[i]*L[i]; - if (slx > 0) { - float sl2 = suml2 - w*L[i]*L[i]; - int new_l = nearest_int(x[i] * sl2 / slx); - new_l = MAX(-nmax, MIN(nmax-1, new_l)); - if (new_l != L[i]) { - slx += w*x[i]*new_l; - sl2 += w*new_l*new_l; - if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { - L[i] = new_l; sumlx = slx; suml2 = sl2; - ++n_changed; - } - } - } - } - if (!n_changed) { - break; - } - } - for (int i = 0; i < n; ++i) { - L[i] += nmax; - } - return sumlx / suml2; - } - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale * x[i]); - l = MAX(-nmax, MIN(nmax-1, l)); - L[i] = l + nmax; - } - return 1/iscale; -} - -static float make_qkx1_quants(int n, int nmax, const float * GGML_RESTRICT x, uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, - int ntry, float alpha) { - float min = x[0]; - float max = x[0]; - for (int i = 1; i < n; ++i) { - if (x[i] < min) min = x[i]; - if (x[i] > max) max = x[i]; - } - if (max == min) { - for (int i = 0; i < n; ++i) L[i] = 0; - *the_min = 0; - return 0.f; - } - if (min > 0) min = 0; - float iscale = nmax/(max - min); - float scale = 1/iscale; - for (int itry = 0; itry < ntry; ++itry) { - float sumlx = 0; int suml2 = 0; - bool did_change = false; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale*(x[i] - min)); - l = MAX(0, MIN(nmax, l)); - if (l != L[i]) { - L[i] = l; - did_change = true; - } - sumlx += (x[i] - min)*l; - suml2 += l*l; - } - scale = sumlx/suml2; - float sum = 0; - for (int i = 0; i < n; ++i) { - sum += x[i] - scale*L[i]; - } - min = alpha*min + (1 - alpha)*sum/n; - if (min > 0) min = 0; - iscale = 1/scale; - if (!did_change) break; - } - *the_min = -min; - return scale; -} - -static float make_qkx2_quants(int n, int nmax, const float * GGML_RESTRICT x, const float * GGML_RESTRICT weights, - uint8_t * GGML_RESTRICT L, float * GGML_RESTRICT the_min, uint8_t * GGML_RESTRICT Laux, - float rmin, float rdelta, int nstep, bool use_mad) { - float min = x[0]; - float max = x[0]; - float sum_w = weights[0]; - float sum_x = sum_w * x[0]; -#ifdef HAVE_BUGGY_APPLE_LINKER - // use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7 - for (volatile int i = 1; i < n; ++i) { -#else - for (int i = 1; i < n; ++i) { -#endif - if (x[i] < min) min = x[i]; - if (x[i] > max) max = x[i]; - float w = weights[i]; - sum_w += w; - sum_x += w * x[i]; - } - if (min > 0) min = 0; - if (max == min) { - for (int i = 0; i < n; ++i) L[i] = 0; - *the_min = -min; - return 0.f; - } - float iscale = nmax/(max - min); - float scale = 1/iscale; - float best_mad = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale*(x[i] - min)); - L[i] = MAX(0, MIN(nmax, l)); - float diff = scale * L[i] + min - x[i]; - diff = use_mad ? fabsf(diff) : diff * diff; - float w = weights[i]; - best_mad += w * diff; - } - if (nstep < 1) { - *the_min = -min; - return scale; - } - for (int is = 0; is <= nstep; ++is) { - iscale = (rmin + rdelta*is + nmax)/(max - min); - float sum_l = 0, sum_l2 = 0, sum_xl = 0; - for (int i = 0; i < n; ++i) { - int l = nearest_int(iscale*(x[i] - min)); - l = MAX(0, MIN(nmax, l)); - Laux[i] = l; - float w = weights[i]; - sum_l += w*l; - sum_l2 += w*l*l; - sum_xl += w*l*x[i]; - } - float D = sum_w * sum_l2 - sum_l * sum_l; - if (D > 0) { - float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; - float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; - if (this_min > 0) { - this_min = 0; - this_scale = sum_xl / sum_l2; - } - float mad = 0; - for (int i = 0; i < n; ++i) { - float diff = this_scale * Laux[i] + this_min - x[i]; - diff = use_mad ? fabsf(diff) : diff * diff; - float w = weights[i]; - mad += w * diff; - } - if (mad < best_mad) { - for (int i = 0; i < n; ++i) { - L[i] = Laux[i]; - } - best_mad = mad; - scale = this_scale; - min = this_min; - } - } - } - *the_min = -min; - return scale; -} - -static inline void get_scale_min_k4(int j, const uint8_t * GGML_RESTRICT q, uint8_t * GGML_RESTRICT d, uint8_t * GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; *m = q[j + 4] & 63; - } else { - *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); - *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); - } -} - -//========================- 2-bit (de)-quantization - -void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - quantize_row_q2_K_ref(x, vy, k); -} - -//========================= 3-bit (de)-quantization - -void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - quantize_row_q3_K_ref(x, vy, k); -} - -// ====================== 4-bit (de)-quantization - -void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK_K == 0); - block_q4_K * GGML_RESTRICT y = vy; - quantize_row_q4_K_ref(x, y, k); -} - -// ====================== 5-bit (de)-quantization - -void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK_K == 0); - block_q5_K * GGML_RESTRICT y = vy; - quantize_row_q5_K_ref(x, y, k); -} - -// ====================== 6-bit (de)-quantization - -void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK_K == 0); - block_q6_K * GGML_RESTRICT y = vy; - quantize_row_q6_K_ref(x, y, k); -} - -// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) - -void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK_K == 0); - block_tq1_0 * GGML_RESTRICT y = vy; - quantize_row_tq1_0_ref(x, y, k); -} - -void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { - assert(k % QK_K == 0); - block_tq2_0 * GGML_RESTRICT y = vy; - quantize_row_tq2_0_ref(x, y, k); -} - -static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - -//===================================== Q8_K ============================================== - -void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { -#ifdef __wasm_simd128__ - assert(k % QK_K == 0); - const int64_t nb = k / QK_K; - block_q8_K * GGML_RESTRICT yc = y; // Cast to proper type - - for (int i = 0; i < nb; i++) { - const float * x_block = x + i * QK_K; - - v128_t min_vec = wasm_v128_load(x_block); - v128_t max_vec = min_vec; - - for (int j = 4; j < QK_K; j += 4) { - v128_t x_vec = wasm_v128_load(x_block + j); - max_vec = wasm_f32x4_pmax(max_vec, x_vec); - min_vec = wasm_f32x4_pmin(min_vec, x_vec); - } - max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1)); - max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2)); - min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1)); - min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2)); - float max = wasm_f32x4_extract_lane(max_vec, 0); - float min = wasm_f32x4_extract_lane(min_vec, 0); - float amax = -min > max ? min : max; - - if (amax == 0.0f) { - yc[i].d = 0.0f; - const v128_t zero = wasm_i8x16_splat(0); - for (int j = 0; j < QK_K; j += 16) { - wasm_v128_store(yc[i].qs + j, zero); - } - continue; - } - - const float iscale = -127.0f / amax; - const v128_t scale_vec = wasm_f32x4_splat(iscale); - - // Process 16 elements per iteration - for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) { - // Load and quantize 16 floats - v128_t x0 = wasm_v128_load(x_block + j); - v128_t x1 = wasm_v128_load(x_block + j + 4); - v128_t x2 = wasm_v128_load(x_block + j + 8); - v128_t x3 = wasm_v128_load(x_block + j + 12); - - v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec)); - v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec)); - v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec)); - v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec)); - - // Convert to i32 with saturation - v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0); - v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1); - v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2); - v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3); - - // Pack into 16 i8 values - v128_t i8 = wasm_i8x16_narrow_i16x8( - wasm_i16x8_narrow_i32x4(i0, i1), - wasm_i16x8_narrow_i32x4(i2, i3) - ); - wasm_v128_store(yc[i].qs + j, i8); - - // Calculate bsums using SIMD - v128_t sum16 = wasm_i16x8_add( - wasm_i16x8_extend_low_i8x16(i8), - wasm_i16x8_extend_high_i8x16(i8) - ); - v128_t sum32 = wasm_i32x4_add( - wasm_i32x4_extend_low_i16x8(sum16), - wasm_i32x4_extend_high_i16x8(sum16) - ); - sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1)); - sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2)); - yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0); - } - - yc[i].d = 1.0f / iscale; - } -#else - quantize_row_q8_K_ref(x, y, k); -#endif -} - -//===================================== Dot products ================================= - -// -// Helper functions -// -#if __AVX__ || __AVX2__ || __AVX512F__ - -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return _mm256_loadu_si256((const __m256i*)k_shuffle + i); -} -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return _mm_loadu_si128((const __m128i*)k_shuffle + i); -} -#elif defined(__loongarch_asx) -// shuffles to pick the required scales in dot products -static inline __m256i get_scale_shuffle_q3k(int i) { - static const uint8_t k_shuffle[128] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, - }; - return __lasx_xvld((const __m256i*)k_shuffle + i, 0); -} -static inline __m256i get_scale_shuffle_k4(int i) { - static const uint8_t k_shuffle[256] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, - 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, - 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, - 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, - 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, - 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, - 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 - }; - return __lasx_xvld((const __m256i*)k_shuffle + i, 0); -} -static inline __m128i get_scale_shuffle(int i) { - static const uint8_t k_shuffle[128] = { - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, - 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, - 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, - 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, - 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 - }; - return __lsx_vld((const __m128i*)k_shuffle + i, 0); -} -#endif - -void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_0 * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q4_0 * GGML_RESTRICT vx0 = vx; - const block_q4_0 * GGML_RESTRICT vx1 = (const block_q4_0 *) ((const uint8_t*)vx + bx); - const block_q8_0 * GGML_RESTRICT vy0 = vy; - const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q4_0 * GGML_RESTRICT b_x0 = &vx0[i]; - const block_q4_0 * GGML_RESTRICT b_x1 = &vx1[i]; - const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i]; - const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); - const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t x0_l = vsubq_s8(v0_0l, s8b); - const int8x16_t x0_h = vsubq_s8(v0_0h, s8b); - const int8x16_t x1_l = vsubq_s8(v0_1l, s8b); - const int8x16_t x1_h = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - float32_t _scale[4] = { - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) - }; - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - - float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - - vst1_f32(s, vget_low_f32 (sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - - return; - } -#endif - - int ib = 0; - float sumf = 0; - -#if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - - const int vector_length = ggml_cpu_get_sve_cnt()*8; - - // VLA Implementation using switch case - switch (vector_length) { - case 128: - { - // predicate for activating higher lanes for 4 float32 elements - const svbool_t ph4 = svptrue_pat_b32(SV_VL4); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - - // 4-bit -> 8-bit - const svint8_t qx0l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx0r, 0x0F)); - const svint8_t qx0h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx0r, 0x04)); - const svint8_t qx1l = svreinterpret_s8_u8(svand_n_u8_m(svptrue_b8(), qx1r, 0x0F)); - const svint8_t qx1h = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), qx1r, 0x04)); - - // sub 8 - const svint8_t qx0ls = svsub_n_s8_x(svptrue_b8(), qx0h, 8); - const svint8_t qx0hs = svsub_n_s8_x(svptrue_b8(), qx0l, 8); - const svint8_t qx1ls = svsub_n_s8_x(svptrue_b8(), qx1h, 8); - const svint8_t qx1hs = svsub_n_s8_x(svptrue_b8(), qx1l, 8); - - // load y - const svint8_t qy0h = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy0l = svld1_s8(svptrue_b8(), y0->qs + 16); - const svint8_t qy1h = svld1_s8(svptrue_b8(), y1->qs); - const svint8_t qy1l = svld1_s8(svptrue_b8(), y1->qs + 16); - - // dot product - sumv0 = svmla_n_f32_x(ph4, sumv0, svcvt_f32_s32_x(ph4, svadd_x(ph4, - svdot_s32(svdup_n_s32(0), qx0ls, qy0l), - svdot_s32(svdup_n_s32(0), qx0hs, qy0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(ph4, sumv1, svcvt_f32_s32_x(ph4, svadd_x(ph4, - svdot_s32(svdup_n_s32(0), qx1ls, qy1l), - svdot_s32(svdup_n_s32(0), qx1hs, qy1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } break; - case 256: - { - // predicate for activating higher lanes for 16 int8 elements - const svbool_t ph16 = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements - const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(svptrue_b8(), x0->qs); - const svuint8_t qx1r = svld1rq_u8(svptrue_b8(), x1->qs); - - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); - - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(svptrue_b8(), qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(svptrue_b8(), qx1, 8); - - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - - // dot product - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } break; - case 512: - { - // predicate for activating higher lanes for 32 int8 elements - const svbool_t ph32 = svptrue_pat_b8(SV_VL32); - - // predicate for activating higher lanes for 16 int8 elements - const svbool_t ph16 = svptrue_pat_b8(SV_VL16); - // predicate for activating lower lanes for 16 int8 elements from first 32 int8 activated lanes - const svbool_t pl16 = svnot_b_z(ph32, ph16); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - // load x - const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs); - const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs); - - // 4-bit -> 8-bit - const svint8_t qx0 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04)); - const svint8_t qx1 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04)); - - // sub 8 - const svint8_t qx0s = svsub_n_s8_x(ph32, qx0, 8); - const svint8_t qx1s = svsub_n_s8_x(ph32, qx1, 8); - - // load y - const svint8_t qy0 = svld1_s8(ph32, y0->qs); - const svint8_t qy1 = svld1_s8(ph32, y1->qs); - - // dot product - sumv0 = svmla_n_f32_x(ph32, sumv0, svcvt_f32_s32_x(ph32, - svdot_s32(svdup_n_s32(0), qx0s, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(ph32, sumv1, svcvt_f32_s32_x(ph32, - svdot_s32(svdup_n_s32(0), qx1s, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1)); - } break; - default: - assert(false && "Unsupported vector length"); - break; - } - -#elif defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - // dot product into int32x4_t - const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined __wasm_simd128__ - v128_t sumv = wasm_f32x4_splat(0.0f); - - const v128_t m4b = wasm_i8x16_splat(0x0F); - const v128_t s8b = wasm_i8x16_splat(0x8); - - for (; ib + 1 < nb; ib += 2) { - const block_q4_0 * GGML_RESTRICT x0 = &x[ib]; - const block_q4_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - // Load and process x0 - v128_t v0_0 = wasm_v128_load(x0->qs); - v128_t v0_0l = wasm_v128_and(v0_0, m4b); - v128_t v0_0h = wasm_u8x16_shr(v0_0, 4); - v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b); - v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b); - - // Load y0 vectors - v128_t y0_l = wasm_v128_load(y0->qs); - v128_t y0_h = wasm_v128_load(y0->qs + 16); - - // Extend to i16x8 and compute dot products - v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls); - v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls); - v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs); - v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs); - - v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l); - v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l); - v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h); - v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h); - - v128_t dp0 = wasm_i32x4_add( - wasm_i32x4_add( - wasm_i32x4_dot_i16x8(dx0l, dy0ll), - wasm_i32x4_dot_i16x8(dx0h, dy0lh) - ), - wasm_i32x4_add( - wasm_i32x4_dot_i16x8(dx0hl, dy0hl), - wasm_i32x4_dot_i16x8(dx0hh, dy0hh) - ) - ); - - // Load and process x1 - v128_t v0_1 = wasm_v128_load(x1->qs); - v128_t v0_1l = wasm_v128_and(v0_1, m4b); - v128_t v0_1h = wasm_u8x16_shr(v0_1, 4); - v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b); - v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b); - - // Load y1 vectors - v128_t y1_l = wasm_v128_load(y1->qs); - v128_t y1_h = wasm_v128_load(y1->qs + 16); - - // Extend to i16x8 and compute dot products - v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls); - v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls); - v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs); - v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs); - - v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l); - v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l); - v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h); - v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h); - - v128_t dp1 = wasm_i32x4_add( - wasm_i32x4_add( - wasm_i32x4_dot_i16x8(dx1l, dy1ll), - wasm_i32x4_dot_i16x8(dx1h, dy1lh) - ), - wasm_i32x4_add( - wasm_i32x4_dot_i16x8(dx1hl, dy1hl), - wasm_i32x4_dot_i16x8(dx1hh, dy1hh) - ) - ); - - // Accumulate results with scaling - float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); - float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d); - - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0))); - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - qx = _mm256_sub_epi8( qx, off ); - - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#elif defined(__AVX__) - __m256 accum = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); - - const __m128i q4b_1_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_1), _mm_set1_epi8(8)); - const __m128i q4b_1_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_1, 4)), _mm_set1_epi8(8)); - const __m128i q4b_2_0 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), q4bits_2), _mm_set1_epi8(8)); - const __m128i q4b_2_1 = _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(q4bits_2, 4)), _mm_set1_epi8(8)); - - const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); - const __m128i p_1 = _mm_add_epi16(p16_1_0, p16_1_1); - const __m128i p_2 = _mm_add_epi16(p16_2_0, p16_2_1); - const __m256 p = sum_i16_pairs_float(p_2, p_1); - - const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); - accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); - } - - sumf = hsum_float_8(accum); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - for (; ib + 1 < nb; ib += 2) { - _mm_prefetch(&x[ib] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[ib] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[ib].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[ib].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[ib + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); - } - - sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - size_t vl = qk / 2; - - for (; ib < nb; ++ib) { - // load elements - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - // subtract offset - vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); - vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed char v8 = vec_splats((signed char)0x8); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 8 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); - - q4x0 = vec_sub(q4x0, v8); - q4x1 = vec_sub(q4x1, v8); - - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi0 = vec_sum4s(qv1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = __lasx_xvreplgr2vr_b( 8 ); - qx = __lasx_xvsub_b( qx, off ); - - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = __lasx_xvfmadd_s( d, q, acc ); - } - - sumf = hsum_float_8(acc); - -#elif defined(__loongarch_sx) - // set constants - const __m128i low_mask = __lsx_vreplgr2vr_b(0xF); - const __m128i off = __lsx_vreplgr2vr_b(8); - - // Initialize accumulator with zeros - __m128 acc_0 = (__m128)__lsx_vldi(0); - __m128 acc_1 = (__m128)__lsx_vldi(0); - __m128 acc_2 = (__m128)__lsx_vldi(0); - __m128 acc_3 = (__m128)__lsx_vldi(0); - - for (; ib + 1 < nb; ib += 2) { - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) ); - - const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0); - - __m128i bx_0 = __lsx_vand_v(low_mask, tmp_0_1); - __m128i by_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); - bx_0 = __lsx_vsub_b(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_0_1, 4)); - __m128i by_1 = __lsx_vld((const __m128i *)(y[ib].qs + 16), 0); - bx_1 = __lsx_vsub_b(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - //_mm_prefetch(&x[ib] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) ); - - const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0); - - __m128i bx_2 = __lsx_vand_v(low_mask, tmp_2_3); - __m128i by_2 = __lsx_vld((const __m128i *)y[ib + 1].qs, 0); - bx_2 = __lsx_vsub_b(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = __lsx_vand_v(low_mask, __lsx_vsrli_d(tmp_2_3, 4)); - __m128i by_3 = __lsx_vld((const __m128i *)(y[ib + 1].qs + 16), 0); - bx_3 = __lsx_vsub_b(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = __lsx_vffint_s_w(i32_0); - __m128 p1 = __lsx_vffint_s_w(i32_1); - __m128 p2 = __lsx_vffint_s_w(i32_2); - __m128 p3 = __lsx_vffint_s_w(i32_3); - - // Apply the scale - __m128 p0_d = __lsx_vfmul_s( d_0_1, p0 ); - __m128 p1_d = __lsx_vfmul_s( d_0_1, p1 ); - __m128 p2_d = __lsx_vfmul_s( d_2_3, p2 ); - __m128 p3_d = __lsx_vfmul_s( d_2_3, p3 ); - - // Acummulate - acc_0 = __lsx_vfadd_s(p0_d, acc_0); - acc_1 = __lsx_vfadd_s(p1_d, acc_1); - acc_2 = __lsx_vfadd_s(p2_d, acc_2); - acc_3 = __lsx_vfadd_s(p3_d, acc_3); - } - - sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); - - const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F); - const __vector int8_t v_s = vec_splats( (const int8_t)0x08); - - for (; ib < nb; ++ib) { - const __vector uint8_t v_x = vec_xl(0, x[ib].qs); - const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m); - const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4); - - const __vector int8_t v_xls = vec_sub(v_xl, v_s); - const __vector int8_t v_xhs = vec_sub(v_xh, v_s); - - const __vector int8_t v_yl = vec_xl(0 , y[ib].qs); - const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs); - - const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl); - const __vector int16_t v_xylse = vec_mule(v_xls, v_yl); - const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh); - const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh); - - __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); - - const __vector float v_xy = vec_float(vec_unpackh(v_xy_)); - const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - - acc = vec_madd(v_xy, v_d, acc); - } - - sumf = acc[0] + acc[1] + acc[2] + acc[3]; -#endif - for (; ib < nb; ++ib) { - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[ib].qs[j] & 0x0F) - 8; - const int v1 = (x[ib].qs[j] >> 4) - 8; - - sumi0 += (v0 * y[ib].qs[j]); - sumi1 += (v1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); - } - - *s = sumf; -} - -void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_1 * GGML_RESTRICT x = vx; - const block_q8_1 * GGML_RESTRICT y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q4_1 * GGML_RESTRICT vx0 = vx; - const block_q4_1 * GGML_RESTRICT vx1 = (const block_q4_1 *) ((const uint8_t*)vx + bx); - const block_q8_1 * GGML_RESTRICT vy0 = vy; - const block_q8_1 * GGML_RESTRICT vy1 = (const block_q8_1 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t summs0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q4_1 * GGML_RESTRICT b_x0 = &vx0[i]; - const block_q4_1 * GGML_RESTRICT b_x1 = &vx1[i]; - const block_q8_1 * GGML_RESTRICT b_y0 = &vy0[i]; - const block_q8_1 * GGML_RESTRICT b_y1 = &vy1[i]; - - float32_t summs_t[4] = { - GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y0->s), - GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y0->s), - GGML_FP16_TO_FP32(b_x0->m) * GGML_FP16_TO_FP32(b_y1->s), - GGML_FP16_TO_FP32(b_x1->m) * GGML_FP16_TO_FP32(b_y1->s) - }; - summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(b_x0->qs); - const uint8x16_t v0_1 = vld1q_u8(b_x1->qs); - - // 4-bit -> 8-bit - const int8x16_t x0_l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t x0_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t x1_l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t x1_h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - // mmla into int32x4_t - float32_t _scale[4] = { - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) - }; - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - - float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - - sumv2 = vaddq_f32(sumv2, summs0); - - vst1_f32(s, vget_low_f32 (sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - - return; - } -#endif - - int ib = 0; - float sumf = 0; - - // TODO: add WASM SIMD -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs = 0; - - for (; ib + 1 < nb; ib += 2) { - const block_q4_1 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q4_1 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_1 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1]; - - summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s) + GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - // dot product into int32x4_t - const int32x4_t p_0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - // Main loop - for (; ib < nb; ++ib) { - const float d0 = GGML_FP16_TO_FP32(x[ib].d); - const float d1 = GGML_FP16_TO_FP32(y[ib].d); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - const __m256 d0v = _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); - - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[ib].qs); - const __m256i qy = _mm256_loadu_si256( (const __m256i *)y[ib].qs ); - - const __m256 xy = mul_sum_us8_pairs_float(qx, qy); - - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - size_t vl = qk / 2; - - for (; ib < nb; ++ib) { - // load elements - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[ib].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[ib].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[ib].qs+16, vl); - - // mask and store lower part of x, and then upper part - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmacc_vv_i16m2(vec_mul1, v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); - vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.0f, 0.0f, 0.0f}; - vsumf0 = vec_madd(vxmin, vys, vsumf0); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector unsigned char q4x0 = (vector unsigned char)vec_and(qxs, lowMask); - vector unsigned char q4x1 = (vector unsigned char)vec_sr(qxs, v4); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_msum(q8y0, q4x0, vsumi0); - vsumi0 = vec_msum(q8y1, q4x1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0; - - // Main loop - for (; ib < nb; ++ib) { - const float d0 = GGML_FP16_TO_FP32(x[ib].d); - const float d1 = GGML_FP16_TO_FP32(y[ib].d); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - const __m256 d0v = __lasx_xvreplfr2vr_s( d0 ); - const __m256 d1v = __lasx_xvreplfr2vr_s( d1 ); - - // Compute combined scales - const __m256 d0d1 = __lasx_xvfmul_s( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i qx = bytes_from_nibbles_32(x[ib].qs); - const __m256i qy = __lasx_xvld( (const __m256i *)y[ib].qs, 0); - - const __m256 xy = mul_sum_us8_pairs_float(qx, qy); - - // Accumulate d0*d1*x*y - acc = __lasx_xvfmadd_s( d0d1, xy, acc ); - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__VXE__) || defined(__VXE2__) - float summs = 0; - float32x4_t acc = vec_splats(0.0f); - - const uint8x16_t v_m = vec_splat_u8(0x0F); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - const uint8x16_t v_x = vec_xl(0, x[ib].qs); - const int8x16_t v_xl = (const int8x16_t)(v_x & v_m); - const int8x16_t v_xh = (const int8x16_t)(v_x >> 4); - - const int8x16_t v_yl = vec_xl(0 , y[ib].qs); - const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs); - - const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); - const float32x4_t v_xy = vec_float(v_xy_); - - const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - - acc = vec_madd(v_xy, v_d, acc); - } - - sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs; -#endif - for (; ib < nb; ++ib) { - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[ib].qs[j] & 0x0F); - const int v1 = (x[ib].qs[j] >> 4); - - sumi0 += (v0 * y[ib].qs[j]); - sumi1 += (v1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - - *s = sumf; -} - -void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - int ib = 0; - float sumf = 0; - - assert(n % qk == 0); - assert(qk == QK5_0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_0 * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - for (; ib + 1 < nb; ib += 2) { - const block_q5_0 * GGML_RESTRICT x0 = &x[ib]; - const block_q5_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined __wasm_simd128__ - v128_t sumv = wasm_f32x4_splat(0.0f); - - uint32_t qh_; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (; ib < nb; ++ib) { - const block_q5_0 * GGML_RESTRICT x0 = &x[ib]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh_, x0->qh, sizeof(qh_)); - - tmp[0] = table_b2b_1[(qh_ >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh_ >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh_ >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - qx = _mm256_or_si256(qx, bxhi); - - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); - } - - sumf = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - - __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); - const __m256i bxhi = bytes_from_bits_32(x[ib].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx_0); - __m128i bxh = _mm256_extractf128_si256(bx_0, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx_0 = MM256_SET_M128I(bxh, bxl); - - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0); - - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } - - sumf = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - size_t vl; - size_t vlenb = __riscv_vlenb(); - - for (; ib < nb; ++ib) { - vl = qk / 2; - vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); - vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); - vint8m2_t v0c; - if (vlenb == 16) { - v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); - } else { - v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); - v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); - } - - vl = qk; - vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); - qh = __riscv_vmnand_mm_b4(qh, qh, vl); - vint8m2_t v0f = __riscv_vsub_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); - vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); - vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); - vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); - int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); - - sumf += (GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)) * sumi; - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed long long aux64x2_0 = {(uint64_t)(table_b2b_1[x[ib].qh[0]]), (uint64_t)(table_b2b_1[x[ib].qh[1]])}; - vector signed long long aux64x2_1 = {(uint64_t)(table_b2b_1[x[ib].qh[2]]), (uint64_t)(table_b2b_1[x[ib].qh[3]])}; - - vector signed char qh0 = (vector signed char)aux64x2_0; - vector signed char qh1 = (vector signed char)aux64x2_1; - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - - vector signed char q5x0 = vec_sub(vec_and (qxs, lowMask), qh0); - vector signed char q5x1 = vec_sub(vec_sr(qxs, v4), qh1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl( 16, y[ib].qs); - - vector signed short qv0 = vec_add(vec_mule(q5x0, q8y0), vec_mulo(q5x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q5x1, q8y1), vec_mulo(q5x1, q8y1)); - - qv0 = vec_add(qv0, qv1); - - vector signed int vsumi0 = vec_add(vec_unpackh(qv0), vec_unpackl(qv0)); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - /* Compute combined scale for the block */ - const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); //FIXME - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = __lasx_xvandn_v(bxhi, __lasx_xvreplgr2vr_b((char)0xF0)); - qx = __lasx_xvor_v(qx, bxhi); - - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - /* Multiply q with scale and accumulate */ - acc = __lasx_xvfmadd_s(d, q, acc); - } - - sumf = hsum_float_8(acc); -#endif - for (; ib < nb; ++ib) { - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); - const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); - - sumi0 += (x0 * y[ib].qs[j]); - sumi1 += (x1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; - } - - *s = sumf; -} - -void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - const int qk = QK8_1; - const int nb = n / qk; - - int ib = 0; - float sumf = 0; - - assert(n % qk == 0); - assert(qk == QK5_1); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_1 * GGML_RESTRICT x = vx; - const block_q8_1 * GGML_RESTRICT y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - for (; ib + 1 < nb; ib += 2) { - const block_q5_1 * GGML_RESTRICT x0 = &x[ib]; - const block_q5_1 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; - const block_q8_1 * GGML_RESTRICT y1 = &y[ib + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - summs0 += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); - summs1 += GGML_FP16_TO_FP32(x1->m) * GGML_FP16_TO_FP32(y1->s); - - // extract the 5th bit via lookup table ((b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - ggml_vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined __wasm_simd128__ - v128_t sumv = wasm_f32x4_splat(0.0f); - - float summs = 0.0f; - - uint32_t qh_; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (; ib < nb; ++ib) { - const block_q5_1 * GGML_RESTRICT x0 = &x[ib]; - const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; - - summs += GGML_FP16_TO_FP32(x0->m) * GGML_FP16_TO_FP32(y0->s); - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh_, x0->qh, sizeof(qh_)); - - tmp[0] = table_b2b_0[(qh_ >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh_ >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh_ >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d)))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - qx = _mm256_or_si256(qx, bxhi); - - const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); - const __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_us8_pairs_float(qx, qy); - - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d)); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - __m256i bx_0 = bytes_from_nibbles_32(x[ib].qs); - const __m256i bxhi = bytes_from_bits_32(x[ib].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx_0); - __m128i bxh = _mm256_extractf128_si256(bx_0, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx_0 = MM256_SET_M128I(bxh, bxl); - - const __m256 dy = _mm256_set1_ps(GGML_FP16_TO_FP32(y[ib].d)); - const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0); - - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); - } - - sumf = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - size_t vl; - size_t vlenb = __riscv_vlenb(); - - for (; ib < nb; ++ib) { - vl = qk / 2; - vuint8m1_t v0 = __riscv_vle8_v_u8m1(x[ib].qs, vl); - vint8m1_t v0l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(v0, 0x0F, vl)); - vint8m1_t v0h = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(v0, 4, vl)); - vint8m2_t v0c; - if (vlenb == 16) { - v0c = __riscv_vcreate_v_i8m1_i8m2(v0l, v0h); - } else { - v0l = __riscv_vslideup_vx_i8m1(v0l, v0h, 16, 32); - v0c = __riscv_vlmul_ext_v_i8m1_i8m2(v0l); - } - - vl = qk; - vbool4_t qh = __riscv_vlm_v_b4(x[ib].qh, vl); - vint8m2_t v0f = __riscv_vor_vx_i8m2_mu(qh, v0c, v0c, 0x10, vl); - vint8m2_t v1 = __riscv_vle8_v_i8m2(y[ib].qs, vl); - vint16m4_t mul = __riscv_vwmul_vv_i16m4(v0f, v1, vl); - vint32m1_t zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t sum = __riscv_vwredsum_vs_i16m4_i32m1(mul, zero, vl); - int32_t sumi = __riscv_vmv_x_s_i32m1_i32(sum); - - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[ib].m)); - vector float vys = {GGML_FP16_TO_FP32(y[ib].s), 0.f, 0.f, 0.f}; - vsumf0 = vec_madd(vxmin, vys, vsumf0); - - vector unsigned long long aux64x2_0 = {(uint64_t)(table_b2b_0[x[ib].qh[0]]), (uint64_t)(table_b2b_0[x[ib].qh[1]])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(table_b2b_0[x[ib].qh[2]]), (uint64_t)(table_b2b_0[x[ib].qh[3]])}; - - vector signed char qh0 = (vector signed char)aux64x2_0; - vector signed char qh1 = (vector signed char)aux64x2_1; - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - - vector unsigned char q5x0 = (vector unsigned char)vec_or(vec_and(qxs, lowMask), qh0); - vector unsigned char q5x1 = (vector unsigned char)vec_or(vec_sr(qxs, v4), qh1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl( 16, y[ib].qs); - - vector signed int vsumi0 = v0; - - vsumi0 = vec_msum(q8y0, q5x0, vsumi0); - vsumi0 = vec_msum(q8y1, q5x1, vsumi0); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - float summs = 0.0f; - - // Main loop - for (; ib < nb; ++ib) { - const __m256 dx = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d)); - - summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s); - - __m256i qx = bytes_from_nibbles_32(x[ib].qs); - __m256i bxhi = bytes_from_bits_32(x[ib].qh); - bxhi = __lasx_xvand_v(bxhi, __lasx_xvreplgr2vr_b(0x10)); - qx = __lasx_xvor_v(qx, bxhi); - - const __m256 dy = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib].d)); - const __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_us8_pairs_float(qx, qy); - - acc = __lasx_xvfmadd_s(q, __lasx_xvfmul_s(dx, dy), acc); - } - - sumf = hsum_float_8(acc) + summs; -#endif - for (; ib < nb; ++ib) { - uint32_t qh; - memcpy(&qh, x[ib].qh, sizeof(qh)); - - int sumi0 = 0; - int sumi1 = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; - - sumi0 += (x0 * y[ib].qs[j]); - sumi1 += (x1 * y[ib].qs[j + qk/2]); - } - - int sumi = sumi0 + sumi1; - sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); - } - - *s = sumf; -} - -void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); -#if defined(__ARM_FEATURE_MATMUL_INT8) - assert((nrc == 2) || (nrc == 1)); -#else - assert(nrc == 1); -#endif - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q8_0 * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - -#if defined(__ARM_FEATURE_MATMUL_INT8) - if (nrc == 2) { - const block_q8_0 * GGML_RESTRICT vx0 = vx; - const block_q8_0 * GGML_RESTRICT vx1 = (const block_q8_0 *) ((const uint8_t*)vx + bx); - const block_q8_0 * GGML_RESTRICT vy0 = vy; - const block_q8_0 * GGML_RESTRICT vy1 = (const block_q8_0 *) ((const uint8_t*)vy + by); - - float32x4_t sumv0 = vdupq_n_f32(0.0f); - - for (int i = 0; i < nb; i++) { - const block_q8_0 * GGML_RESTRICT b_x0 = &vx0[i]; - const block_q8_0 * GGML_RESTRICT b_y0 = &vy0[i]; - - const block_q8_0 * GGML_RESTRICT b_x1 = &vx1[i]; - const block_q8_0 * GGML_RESTRICT b_y1 = &vy1[i]; - - const int8x16_t x0_l = vld1q_s8(b_x0->qs); - const int8x16_t x0_h = vld1q_s8(b_x0->qs + 16); - const int8x16_t x1_l = vld1q_s8(b_x1->qs); - const int8x16_t x1_h = vld1q_s8(b_x1->qs + 16); - - // load y - const int8x16_t y0_l = vld1q_s8(b_y0->qs); - const int8x16_t y0_h = vld1q_s8(b_y0->qs + 16); - const int8x16_t y1_l = vld1q_s8(b_y1->qs); - const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - - float32_t _scale[4] = { - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x0->d)*GGML_FP16_TO_FP32(b_y1->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y0->d), - GGML_FP16_TO_FP32(b_x1->d)*GGML_FP16_TO_FP32(b_y1->d) - }; - float32x4_t scale = vld1q_f32(_scale); - - int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - int8x16_t l1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); - - int8x16_t l2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - int8x16_t l3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(x0_h), vreinterpretq_s64_s8(x1_h))); - - int8x16_t r0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - int8x16_t r1 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_l), vreinterpretq_s64_s8(y1_l))); - - int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); - - sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); - } - - float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); - float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - - vst1_f32(s, vget_low_f32 (sumv2)); - vst1_f32(s + bs, vget_high_f32(sumv2)); - - return; - } -#endif - - int ib = 0; - float sumf = 0; - -#if defined(__ARM_FEATURE_SVE) - svfloat32_t sumv0 = svdup_n_f32(0.0f); - svfloat32_t sumv1 = svdup_n_f32(0.0f); - - const int vector_length = ggml_cpu_get_sve_cnt()*8; - - //VLA Implemenation for SVE - switch (vector_length) { - case 128: - { - // predicate for activating lanes for 16 Int8 elements - const svbool_t ph16 = svptrue_pat_b8 (SV_VL16); - const svbool_t pl16 = svptrue_pat_b32(SV_VL4); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - // load x - const svint8_t qx0_0 = svld1_s8(ph16, x0->qs); - const svint8_t qx0_1 = svld1_s8(ph16, x0->qs+16); - const svint8_t qx1_0 = svld1_s8(ph16, x1->qs); - const svint8_t qx1_1 = svld1_s8(ph16, x1->qs+16); - - // load y - const svint8_t qy0_0 = svld1_s8(ph16, y0->qs); - const svint8_t qy0_1 = svld1_s8(ph16, y0->qs+16); - const svint8_t qy1_0 = svld1_s8(ph16, y1->qs); - const svint8_t qy1_1 = svld1_s8(ph16, y1->qs+16); - - sumv0 = svmla_n_f32_x(pl16, sumv0, svcvt_f32_s32_x(pl16, svadd_x(pl16, - svdot_s32(svdup_n_s32(0), qx0_0, qy0_0), - svdot_s32(svdup_n_s32(0), qx0_1, qy0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(pl16, sumv1, svcvt_f32_s32_x(pl16, svadd_x(pl16, - svdot_s32(svdup_n_s32(0), qx1_0, qy1_0), - svdot_s32(svdup_n_s32(0), qx1_1, qy1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(pl16, svadd_f32_x(pl16, sumv0, sumv1)); - } break; - case 256: - { - //printf("sve256"); - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - // load x - const svint8_t qx0 = svld1_s8(svptrue_b8(), x0->qs); - const svint8_t qx1 = svld1_s8(svptrue_b8(), x1->qs); - - // load y - const svint8_t qy0 = svld1_s8(svptrue_b8(), y0->qs); - const svint8_t qy1 = svld1_s8(svptrue_b8(), y1->qs); - - sumv0 = svmla_n_f32_x(svptrue_b32(), sumv0, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx0, qy0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - sumv1 = svmla_n_f32_x(svptrue_b32(), sumv1, svcvt_f32_s32_x(svptrue_b32(), - svdot_s32(svdup_n_s32(0), qx1, qy1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1)); - } break; - case 512: - { - // predicate for activating high 256 bit - const svbool_t ph32 = svptrue_pat_b8(SV_VL32); - // predicate for activating low 256 bit - const svbool_t pl32 = svnot_b_z(svptrue_b8(), ph32); - - // predicate for activating high lanes for 8 float32 elements - const svbool_t ph8 = svptrue_pat_b32(SV_VL8); - // predicate for activating low lanes for 8 float32 elements - const svbool_t pl8 = svnot_b_z(svptrue_b32(), ph8); - - svfloat32_t sumv00 = svdup_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - //load 32 int8_t in first half of vector and put another 32 int8_t in second vector lower bits - // and add them to make one 64 element vector - // load x - const svint8_t qx_32 = svld1_s8(ph32, x0->qs); - svint8_t qx_64 = svld1_s8(pl32, x0->qs + 2); - - qx_64 = svadd_s8_x(svptrue_b8(), qx_32, qx_64); - - // load y - const svint8_t qy_32 = svld1_s8(ph32, y0->qs); - svint8_t qy_64 = svld1_s8(pl32, y0->qs + 2); - - qy_64 = svadd_s8_x(svptrue_b8(), qy_32, qy_64); - - // scale creation - const float32_t deq1 = GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d); - const float32_t deq2 = GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d); - - // duplicate deq1 in first half of vector and deq2 in second half of vector - const svfloat32_t temp = svdup_f32_m(svdup_f32_z(ph8, deq1), pl8, deq2); - - const svfloat32_t sumvt = svcvt_f32_s32_x(svptrue_b32(), svdot_s32(svdup_n_s32(0), qx_64, qy_64)); - - sumv00 = svmla_f32_m(svptrue_b32(), sumv00, sumvt, temp); - } - - sumf = svaddv_f32(svptrue_b32(), sumv00); - break; - } - default: - assert(false && "Unsupported vector length"); - break; - } -#elif defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - for (; ib + 1 < nb; ib += 2) { - const block_q8_0 * GGML_RESTRICT x0 = &x[ib + 0]; - const block_q8_0 * GGML_RESTRICT x1 = &x[ib + 1]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0]; - const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1]; - - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - ggml_vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); - - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - ggml_vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); - } - - sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined __wasm_simd128__ - v128_t sumv = wasm_f32x4_splat(0.0f); - - for (; ib < nb; ++ib) { - const block_q8_0 * GGML_RESTRICT x0 = &x[ib]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; - - const v128_t x0_0 = wasm_v128_load(x0->qs); - const v128_t x0_1 = wasm_v128_load(x0->qs + 16); - const v128_t y0_0 = wasm_v128_load(y0->qs); - const v128_t y0_1 = wasm_v128_load(y0->qs + 16); - - // Extend 8-bit to 16-bit - const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0); - const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0); - const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1); - const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1); - - const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0); - const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0); - const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1); - const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1); - - // Compute dot products - const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l); - const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h); - const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l); - const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h); - - // Sum all dot products - const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1)); - - // Convert to float and accumulate - const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d); - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale))); - } - - sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = _mm256_loadu_si256((const __m256i *)x[ib].qs); - __m256i qy = _mm256_loadu_si256((const __m256i *)y[ib].qs); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - // Multiply q with scale and accumulate - acc = _mm256_fmadd_ps( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#elif defined(__AVX__) - __m256 accum = _mm256_setzero_ps(); - - for (; ib + 1 < nb; ib += 2) { - const __m128i qx_1_0 = _mm_loadu_si128((const __m128i *)x[ib].qs); - const __m128i qx_1_1 = _mm_loadu_si128((const __m128i *)x[ib].qs + 1); - const __m128i qx_2_0 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - const __m128i qx_2_1 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs + 1); - const __m128i qy_1_0 = _mm_loadu_si128((const __m128i *)y[ib].qs); - const __m128i qy_1_1 = _mm_loadu_si128((const __m128i *)y[ib].qs + 1); - const __m128i qy_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - const __m128i qy_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); - - const __m256 p = mul_sum_i8_quad_float(qx_1_0, qx_1_1, qx_2_0, qx_2_1, qy_1_0, qy_1_1, qy_2_0, qy_2_1); - const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); - accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); - } - - sumf = hsum_float_8(accum); -#elif defined(__riscv_v_intrinsic) - size_t vl = qk; - - for (; ib < nb; ++ib) { - // load elements - vint8m2_t bx_0 = __riscv_vle8_v_i8m2(x[ib].qs, vl); - vint8m2_t by_0 = __riscv_vle8_v_i8m2(y[ib].qs, vl); - - vint16m4_t vw_mul = __riscv_vwmul_vv_i16m4(bx_0, by_0, vl); - - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m4_i32m1(vw_mul, v_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); - - sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); - } -#elif defined(__POWER9_VECTOR__) - const vector signed int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - -#pragma GCC unroll 8 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char q8x0 = vec_xl( 0, x[ib].qs); - vector signed char q8x1 = vec_xl(16, x[ib].qs); - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed short qv0 = vec_mule(q8x0, q8y0); - vector signed short qv1 = vec_mulo(q8x0, q8y0); - vector signed short qv2 = vec_mule(q8x1, q8y1); - vector signed short qv3 = vec_mulo(q8x1, q8y1); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi1 = vec_sum4s(qv1, vsumi1); - vsumi0 = vec_sum4s(qv2, vsumi0); - vsumi1 = vec_sum4s(qv3, vsumi1); - - vsumi0 = vec_add(vsumi0, vsumi1); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - } - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - // Initialize accumulator with zeros - __m256 acc = (__m256)__lasx_xvldi(0); - - // Main loop - for (; ib < nb; ++ib) { - // Compute combined scale for the block - const __m256 d = __lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - __m256i qx = __lasx_xvld((const __m256i *)x[ib].qs, 0); - __m256i qy = __lasx_xvld((const __m256i *)y[ib].qs, 0); - - const __m256 q = mul_sum_i8_pairs_float(qx, qy); - - // Multiply q with scale and accumulate - acc = __lasx_xvfmadd_s( d, q, acc ); - } - - sumf = hsum_float_8(acc); -#elif defined(__VXE__) || defined(__VXE2__) - __vector float acc = vec_splats(0.0f); - -#pragma GCC unroll 8 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - const int8x16_t v_xl = vec_xl(0 , x[ib].qs); - const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs); - const int8x16_t v_yl = vec_xl(0 , y[ib].qs); - const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); - - const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); - const float32x4_t v_xy = vec_float(v_xy_); - const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d)); - - acc = vec_madd(v_xy, v_d, acc); - } - - sumf = acc[0] + acc[1] + acc[2] + acc[3]; -#endif - for (; ib < nb; ++ib) { - int sumi = 0; - - for (int j = 0; j < qk; j++) { - sumi += x[ib].qs[j]*y[ib].qs[j]; - } - - sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); - } - - *s = sumf; -} - -void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq1_0 * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - float sumf = 0.0f; - - uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; - - const uint8x16_t shift = vld1q_u8(k_shift); - - for (int i = 0; i < nb; ++i) { -#if defined(__ARM_FEATURE_DOTPROD) - int32x4_t sumi0 = vdupq_n_s32(0); - int32x4_t sumi1 = vdupq_n_s32(0); -#else - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); -#endif - - // first 32 bytes of 5 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); - uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); - uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); - uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); - int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); - int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); - const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); - const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); - sumi0 = vdotq_s32(sumi0, sqx6, qy6); - sumi1 = vdotq_s32(sumi1, sqx7, qy7); - sumi0 = vdotq_s32(sumi0, sqx8, qy8); - sumi1 = vdotq_s32(sumi1, sqx9, qy9); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); -#endif - } - - // last 16 bytes of 5-element, along with the 4 bytes of 4 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); - uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned - uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); - qx5 = vmulq_u8(qx5, shift); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); -#endif - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - sumf += d * (float) vaddvq_s32(sumi0); -#else - sumi0 = vaddq_s16(sumi0, sumi1); - sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - - sumf += d * (float) vaddlvq_s16(sumi0); -#endif - } - - *s = sumf; - -#elif defined(__AVX2__) - __m256 sumf = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - // 16-bit sums - __m256i sumi0 = _mm256_setzero_si256(); - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - - // first 32 bytes of 5 elements - { - __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); - // 8-bit multiplies with shifts, masks and adds - __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 - __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 - __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 - __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 - - // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? - - // Cancel the +1 from avg so that it behaves like a halving add - qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); - qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); - qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); - qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); - qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); - // Multiply by 3 and get the top 2 bits - qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); - qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); - qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); - qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); - qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); - qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); - qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); - qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); - qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); - qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); - - const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); - const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); - const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); - const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); - const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); - - qx0 = _mm256_maddubs_epi16(qx0, qy0); - qx1 = _mm256_maddubs_epi16(qx1, qy1); - qx2 = _mm256_maddubs_epi16(qx2, qy2); - qx3 = _mm256_maddubs_epi16(qx3, qy3); - qx4 = _mm256_maddubs_epi16(qx4, qy4); - - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); - sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); - sumi2 = _mm256_add_epi16(sumi2, qx4); - } - - // last 16 bytes of 5-element, along with the 4 bytes of 4 elements - { - __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned - __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); - __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 - __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 - __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 - __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 - __m256i qx01 = MM256_SET_M128I(qx1, qx0); - __m256i qx23 = MM256_SET_M128I(qx3, qx2); - - // avx2 does not have 8-bit multiplies, so 16-bit it is. - qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); - qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); - __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); - - __m256i qx45 = MM256_SET_M128I(qx5, qx4); - - // Cancel the +1 from avg so that it behaves like a halving add - qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); - qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); - qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); - // Multiply by 3 and get the top 2 bits - qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); - qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); - qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); - qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); - qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); - qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); - - const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); - const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); - const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); - - qx01 = _mm256_maddubs_epi16(qx01, qy01); - qx23 = _mm256_maddubs_epi16(qx23, qy23); - qx45 = _mm256_maddubs_epi16(qx45, qy45); - - sumi0 = _mm256_add_epi16(sumi0, qx01); - sumi1 = _mm256_add_epi16(sumi1, qx23); - sumi2 = _mm256_add_epi16(sumi2, qx45); - } - - const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); - - sumi0 = _mm256_sub_epi16(sumi0, ysum); - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); - sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); - - sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); - } - - *s = hsum_float_8(sumf); - -#else - const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; - - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - int sum = 0; - - for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { - for (size_t l = 0; l < 5; ++l) { - for (size_t m = 0; m < 32; ++m) { - uint8_t q = x[i].qs[j + m] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; - } - } - } - for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { - for (size_t l = 0; l < 5; ++l) { - for (size_t m = 0; m < 16; ++m) { - uint8_t q = x[i].qs[j + m] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; - } - } - } - - for (size_t l = 0; l < 4; ++l) { - for (size_t j = 0; j < sizeof(x->qh); ++j) { - uint8_t q = x[i].qh[j] * pow3[l]; - uint16_t xi = ((uint16_t) q * 3) >> 8; - sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; - } - } - - sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d); - } - - *s = sumf; -#endif -} - -void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_tq2_0 * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - float sumf = 0.0f; - - const uint8x16_t m3 = vdupq_n_u8(3); - - for (int i = 0; i < nb; ++i) { -#if defined(__ARM_FEATURE_DOTPROD) - int32x4_t sumi0 = vdupq_n_s32(0); - int32x4_t sumi1 = vdupq_n_s32(0); -#else - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); -#endif - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - uint8x16_t qx0 = vld1q_u8(x[i].qs + j); - uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); - uint8x16_t qx2 = vshrq_n_u8(qx0, 2); - uint8x16_t qx3 = vshrq_n_u8(qx1, 2); - uint8x16_t qx4 = vshrq_n_u8(qx0, 4); - uint8x16_t qx5 = vshrq_n_u8(qx1, 4); - uint8x16_t qx6 = vshrq_n_u8(qx0, 6); - uint8x16_t qx7 = vshrq_n_u8(qx1, 6); - - int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); - sumi0 = vdotq_s32(sumi0, sqx6, qy6); - sumi1 = vdotq_s32(sumi1, sqx7, qy7); -#else - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); - sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); - sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); -#endif - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - -#if defined(__ARM_FEATURE_DOTPROD) - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - sumf += d * (float) vaddvq_s32(sumi0); -#else - sumi0 = vaddq_s16(sumi0, sumi1); - sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - - sumf += d * (float) vaddlvq_s16(sumi0); -#endif - } - - *s = sumf; - -#elif defined(__AVX2__) - __m256 sumf = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - // 16-bit sums, because 256*127 still fits - __m256i sumi0 = _mm256_setzero_si256(); - __m256i sumi1 = _mm256_setzero_si256(); - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); - __m256i qx1 = _mm256_srli_epi16(qx0, 2); - __m256i qx2 = _mm256_srli_epi16(qx0, 4); - __m256i qx3 = _mm256_srli_epi16(qx0, 6); - - // 0, 1, 2 (should not be 3) - qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); - qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); - qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); - qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); - - const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); - const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); - const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); - const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); - - qx0 = _mm256_maddubs_epi16(qx0, qy0); - qx1 = _mm256_maddubs_epi16(qx1, qy1); - qx2 = _mm256_maddubs_epi16(qx2, qy2); - qx3 = _mm256_maddubs_epi16(qx3, qy3); - - sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); - sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); - } - - const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); - - sumi0 = _mm256_add_epi16(sumi0, sumi1); - sumi0 = _mm256_sub_epi16(sumi0, ysum); - sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); - - sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); - } - - *s = hsum_float_8(sumf); - -#else - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - int32_t sumi = 0; - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - for (size_t l = 0; l < 4; ++l) { - for (size_t k = 0; k < 32; ++k) { - sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); - } - } - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - sumf += (float) sumi * d; - } - - *s = sumf; -#endif -} - -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q2_K * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_FEATURE_SVE - const int vector_length = svcntb()*8; - const svuint8_t m3s = svdup_n_u8(0x3); - const svuint32_t m4s = svdup_n_u32(0xF); - const svint32_t vzero_sv = svdup_n_s32(0); - svfloat32_t acc_sum = svdup_n_f32(0); - svbool_t pred_s32 = svptrue_pat_b32(SV_VL4); - - switch (vector_length) { - case 128: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - svfloat32_t d_broad = svdup_n_f32((float32_t)d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8_sv = y[i].qs; - const uint8_t * GGML_RESTRICT sc = x[i].scales; - - svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc); - const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); - - mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+4); - const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); - - svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums); - svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+4); - - const svint32_t s0 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_2, q8sums_sv_2)); - - mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+8); - const svint32_t mins_sv_3 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); - - mins_and_scales_sve = svld1ub_u32(svptrue_b32(), sc+12); - const svint32_t mins_sv_4 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_b32(), mins_and_scales_sve, 4)); - - q8sums_sv_1 = svld1sh_s32(svptrue_b32(), y[i].bsums+8); - q8sums_sv_2 = svld1sh_s32(svptrue_b32(), y[i].bsums+12); - - svint32_t s1 = svadd_s32_x(svptrue_b32(), svmul_s32_x(svptrue_b32(), mins_sv_3, q8sums_sv_1), svmul_s32_x(svptrue_b32(), mins_sv_4, q8sums_sv_2)); - - svfloat32_t temp = svcvt_f32_s32_x(svptrue_b32(), svadd_s32_x(svptrue_b32(), s0, s1)); - - acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, temp, dmin_broad); - - svint32_t sumi1 = svdup_n_s32(0); - - { - const svuint8_t q2bits_1 = svld1_u8(svptrue_b8(), q2); - svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_1, m3s)); - svint8_t q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc), m4s)); - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 0)); - - const svuint8_t q2bits_3 = svld1_u8(svptrue_b8(), q2+16); - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_3, m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 1)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 2), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 2)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 2), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv, 3)); - - - const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+4), m4s)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 4), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 0)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 4), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 1)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_1, 6), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 2)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_3, 6), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_1, 3)); - - //------------------------------- - - q2 += 32; - const svint32_t scales_sv_2 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+8), m4s)); - const svuint8_t q2bits_2 = svld1_u8(svptrue_b8(), q2); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_2, m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 0)); - - const svuint8_t q2bits_4 = svld1_u8(svptrue_b8(), q2+16); - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), q2bits_4, m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 1)); - - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 2), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 2)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 2), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_2, 3)); - - - const svint32_t scales_sv_3 = svreinterpret_s32_u32(svand_u32_m(svptrue_b32(), svld1ub_u32(svptrue_b32(), sc+12), m4s)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 4), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 0)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 4), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 1)); - - - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_2, 6), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 2)); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q2bits_4, 6), m3s)); - q8bytes_sv = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - sumi1 = svmla_s32_m(svptrue_b32(), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), svdup_lane_s32(scales_sv_3, 3)); - } - acc_sum = svmla_f32_m(svptrue_b32(), acc_sum, svcvt_f32_s32_x(svptrue_b32(), sumi1), d_broad); - } - *s = svaddv_f32(svptrue_b32(), acc_sum); - break; - - case 256: - case 512: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - svfloat32_t d_broad = svdup_n_f32((float32_t)d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - svfloat32_t dmin_broad = svdup_n_f32((float32_t)dmin); - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8_sv = y[i].qs; - const uint8_t * GGML_RESTRICT sc = x[i].scales; - - const svuint32_t mins_and_scales_sve = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); sc += 8; - const svint32_t scales_sv = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, m4s)); - const svint32_t mins_sv_1 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve, 4)); - svint32_t q8sums_sv_1 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums); - - const svuint32_t mins_and_scales_sve_1 = svld1ub_u32(svptrue_pat_b32(SV_VL8), sc); - const svint32_t scales_sv_1 = svreinterpret_s32_u32(svand_u32_m(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, m4s)); - const svint32_t mins_sv_2 = svreinterpret_s32_u32(svlsr_n_u32_x(svptrue_pat_b32(SV_VL8), mins_and_scales_sve_1, 4)); - - svint32_t q8sums_sv_2 = svld1sh_s32(svptrue_pat_b32(SV_VL8), y[i].bsums+8); - - svfloat32_t temp = svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_1, q8sums_sv_1), svmul_s32_x(svptrue_pat_b32(SV_VL8), mins_sv_2, q8sums_sv_2))); - - acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, temp, dmin_broad); - - svint32_t sumi1 = svdup_n_s32(0); - - { - const svuint8_t q2bits_1 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); - svint8_t q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_1, m3s)); - svint8_t q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - svint32_t scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 0), svdup_lane_s32(scales_sv, 1)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 2), m3s)); - q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - svint32_t scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 2), svdup_lane_s32(scales_sv, 3)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(svdup_n_s32(0), q2bytes_sv, q8bytes_sv), scale_2); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 4), m3s)); - q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv, 4), svdup_lane_s32(scales_sv, 5)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_1, 6), m3s)); - q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv, 6), svdup_lane_s32(scales_sv, 7)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); - - q2 += 32; - - const svuint8_t q2bits_2 = svld1_u8(svptrue_pat_b8(SV_VL32), q2); - q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q2bits_2, m3s)); - q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 0), svdup_lane_s32(scales_sv_1, 1)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 2), m3s)); - q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 2), svdup_lane_s32(scales_sv_1, 3)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 4), m3s)); - q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - scale_1 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 4), svdup_lane_s32(scales_sv_1, 5)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_1); - - q2bytes_sv = svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q2bits_2, 6), m3s)); - q8bytes_sv = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - scale_2 = svsel(pred_s32, svdup_lane_s32(scales_sv_1, 6), svdup_lane_s32(scales_sv_1, 7)); - sumi1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(vzero_sv, q2bytes_sv, q8bytes_sv), scale_2); - } - acc_sum = svmla_f32_m(svptrue_pat_b32(SV_VL8), acc_sum, svcvt_f32_s32_x(svptrue_pat_b32(SV_VL8), sumi1), d_broad); - } - *s = svaddv_f32(svptrue_pat_b32(SV_VL8), acc_sum); - break; - - default: - assert(false && "Unsupported vector length"); - break; - } - -#elif __ARM_NEON - const uint8x16_t m3 = vdupq_n_u8(0x3); - const uint8x16_t m4 = vdupq_n_u8(0xF); - - const int32x4_t vzero = vdupq_n_s32(0); - - ggml_int8x16x2_t q2bytes; - uint8_t aux[16]; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - const uint8_t * GGML_RESTRICT sc = x[i].scales; - - const uint8x16_t mins_and_scales = vld1q_u8(sc); - const uint8x16_t scales = vandq_u8(mins_and_scales, m4); - vst1q_u8(aux, scales); - - const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); - const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); - const ggml_int16x8x2_t mins16 = {{vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}}; - const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), - vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); - const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), - vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); - sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); - - int isum = 0; - int is = 0; - -// We use this macro instead of a function call because for some reason -// the code runs 2-3% slower, even if the function is declared inline -#define MULTIPLY_ACCUM_WITH_SCALE(index)\ - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; - -#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ - q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\ - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ - MULTIPLY_ACCUM_WITH_SCALE((index)); - - for (int j = 0; j < QK_K/128; ++j) { - const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32; - - ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - - MULTIPLY_ACCUM_WITH_SCALE(0); - - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); - SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); - - is += 8; - } - - sum += d * isum; - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m128i m4 = _mm_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m256i mins = _mm256_cvtepi8_epi16(mins8); - const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); - - const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i q2_0 = _mm256_and_si256(q2bits, m3); - const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); - const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); - const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); - - __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); - __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); - __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); - __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); - - p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); - p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); - p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); - p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); - - p0 = _mm256_add_epi32(p0, p1); - p2 = _mm256_add_epi32(p2, p3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(0x3); - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - // load mins and scales from block_q2_K.scales[QK_K/16] - const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); - const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); - const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); - const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); - - // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 - const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); - const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); - - // sumf += -dmin * summs in 32bits*8 - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); - - const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); - const __m128i scales[2] = { scales_0, scales_1 }; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - - // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // load 2bits*16*8 from block_q2_K.qs[QK_K/4] - __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_0 = _mm_and_si128(q2bits, m3); - const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; - const __m128i q2_1 = _mm_and_si128(q2bits, m3); - const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); - const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); - const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); - - // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 - __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); - __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); - __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); - __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); - __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); - __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); - __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); - __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); - - // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 - __m128i shuffle = _mm_set1_epi16(0x0100); - p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); - shuffle = _mm_add_epi16(shuffle, m2); - p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); - shuffle = _mm_add_epi16(shuffle, m2); - p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); - shuffle = _mm_add_epi16(shuffle, m2); - p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); - shuffle = _mm_add_epi16(shuffle, m2); - p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); - shuffle = _mm_add_epi16(shuffle, m2); - p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); - shuffle = _mm_add_epi16(shuffle, m2); - p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); - shuffle = _mm_add_epi16(shuffle, m2); - p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); - - p0 = _mm_add_epi32(p0, p1); - p2 = _mm_add_epi32(p2, p3); - p4 = _mm_add_epi32(p4, p5); - p6 = _mm_add_epi32(p6, p7); - - // isum in 32bits*4*2 - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); - } - - // sumf += dall * isum - dmin * summs in 32bits - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __wasm_simd128__ - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - // Vectorized summs calculation - v128_t summs_vec = wasm_i32x4_splat(0); - { - v128_t sc_vec = wasm_v128_load(sc); - v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4); - - v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper); - v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper); - - v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]); - v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]); - - summs_vec = wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1), - wasm_i32x4_dot_i16x8(sc_high, bsums2)), - summs_vec - ); - - summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1)); - summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2)); - } - int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0); - - // Vectorized isum calculation - int32_t isum = 0; - const uint8_t * sc_ptr = sc; - const int k_iters = QK_K/128; - - for (int k = 0; k < k_iters; ++k) { - v128_t isum_vec = wasm_i32x4_splat(0); - int shift = 0; - - for (int j = 0; j < 4; ++j) { - const int d0 = (sc_ptr[0] & 0xF); - const int d1 = (sc_ptr[1] & 0xF); - sc_ptr += 2; - - // Process first 16 elements - v128_t q2_0 = wasm_v128_load(q2); - v128_t q8_0 = wasm_v128_load(q8); - v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift); - v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03)); - - // Process next 16 elements - v128_t q2_1 = wasm_v128_load(q2 + 16); - v128_t q8_1 = wasm_v128_load(q8 + 16); - v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift); - v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03)); - - // Calculate dot products - v128_t p0 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q8_0), - wasm_i16x8_extend_low_i8x16(q2_bits_0) - ); - v128_t p1 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q8_0), - wasm_i16x8_extend_high_i8x16(q2_bits_0) - ); - v128_t p2 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q8_1), - wasm_i16x8_extend_low_i8x16(q2_bits_1) - ); - v128_t p3 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q8_1), - wasm_i16x8_extend_high_i8x16(q2_bits_1) - ); - - // Accumulate scaled results - v128_t scaled = wasm_i32x4_add( - wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)), - wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1)) - ); - - isum_vec = wasm_i32x4_add(isum_vec, scaled); - q8 += 32; - shift += 2; - } - q2 += 32; - - // Horizontal sum of isum_vec - isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1)); - isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2)); - isum += wasm_i32x4_extract_lane(isum_vec, 0); - } - - const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf += dall * isum - dmin * summs; - } - - *s = sumf; - -#elif defined __riscv_v_intrinsic - - const int vector_length = __riscv_vlenb() * 8; - float sumf = 0; - - uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - uint8_t atmp[16]; - - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - size_t vl = 16; - - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - - vl = 32; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - - uint8_t is = 0; - int isum = 0; - - for (int j = 0; j < QK_K / 128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); - - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); - - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); - - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); - - isum += __riscv_vmv_x_s_i32m1_i32(isum1); - - q2 += 32; - q8 += 128; - is = 8; - } - - sumf += dall * isum; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - uint8_t *patmp = atmp; - int vsums; - int tmp; - __asm__ __volatile__( - "vsetivli zero, 16, e8, m1\n\t" - "vmv.v.x v8, zero\n\t" - "vle8.v v1, (%[sc])\n\t" - "vand.vi v0, v1, 0xF\n\t" - "vsrl.vi v1, v1, 4\n\t" - "vse8.v v0, (%[scale])\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vle16.v v2, (%[bsums])\n\t" - "vzext.vf2 v0, v1\n\t" - "vwmul.vv v4, v0, v2\n\t" - "vsetivli zero, 16, e32, m4\n\t" - "vredsum.vs v8, v4, v8\n\t" - "vmv.x.s %[vsums], v8" - : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) - : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - sumf += dmin * vsums; - int isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "vsetvli zero, %[vl32], e8, m2\n\t" - "vle8.v v0, (%[q2])\n\t" - "vsrl.vi v2, v0, 2\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v8, (%[q8])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v10, v9, 1\n\t" - "vslideup.vi v8, v7, 1\n\t" - "vslideup.vi v11, v12, 1\n\t" - "vslideup.vi v13, v14, 1\n\t" - "vslideup.vi v10, v8, 2\n\t" - "vslideup.vi v11, v13, 2\n\t" - "vsetivli zero, 8, e32, m2\n\t" - "vle8.v v15, (%[scale])\n\t" - "vzext.vf4 v12, v15\n\t" - "vmul.vv v10, v10, v12\n\t" - "vredsum.vs v0, v10, v0\n\t" - "vmv.x.s %[tmp], v0\n\t" - "add %[isum], %[isum], %[tmp]" - : [tmp] "=&r" (tmp), [isum] "+&r" (isum) - : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q2 += 32; q8 += 128; patmp += 8; - } - - sumf += dall * isum; - } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char lowScaleMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - vector signed char q2xmins = (vector signed char)vec_xl( 0, x[i].scales); - vector signed char vscales = vec_and(q2xmins, lowScaleMask); - - q2xmins = vec_sr(q2xmins, v4); - vector signed short q2xmins0 = vec_unpackh(q2xmins); - vector signed short q2xmins1 = vec_unpackl(q2xmins); - - vector signed int prod0 = vec_mule(q2xmins0, q8ysums0); - vector signed int prod1 = vec_mulo(q2xmins0, q8ysums0); - vector signed int prod2 = vec_mule(q2xmins1, q8ysums1); - vector signed int prod3 = vec_mulo(q2xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q2); - vector signed char qxs1 = (vector signed char)vec_xl(16, q2); - q2 += 32; - - vector unsigned char q2x00 = (vector unsigned char)vec_and(qxs0, lowMask); - vector unsigned char q2x01 = (vector unsigned char)vec_and(vec_sr(qxs0, v2), lowMask); - vector unsigned char q2x02 = (vector unsigned char)vec_and(vec_sr(qxs0, v4), lowMask); - vector unsigned char q2x03 = (vector unsigned char)vec_and(vec_sr(qxs0, v6), lowMask); - vector unsigned char q2x10 = (vector unsigned char)vec_and(qxs1, lowMask); - vector unsigned char q2x11 = (vector unsigned char)vec_and(vec_sr(qxs1, v2), lowMask); - vector unsigned char q2x12 = (vector unsigned char)vec_and(vec_sr(qxs1, v4), lowMask); - vector unsigned char q2x13 = (vector unsigned char)vec_and(vec_sr(qxs1, v6), lowMask); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y02 = vec_xl( 64, q8); - vector signed char q8y12 = vec_xl( 80, q8); - vector signed char q8y03 = vec_xl( 96, q8); - vector signed char q8y13 = vec_xl(112, q8); - q8 += 128; - - vector signed int qv0 = vec_msum(q8y00, q2x00, v0); - vector signed int qv1 = vec_msum(q8y01, q2x01, v0); - vector signed int qv2 = vec_msum(q8y02, q2x02, v0); - vector signed int qv3 = vec_msum(q8y03, q2x03, v0); - vector signed int qv4 = vec_msum(q8y10, q2x10, v0); - vector signed int qv5 = vec_msum(q8y11, q2x11, v0); - vector signed int qv6 = vec_msum(q8y12, q2x12, v0); - vector signed int qv7 = vec_msum(q8y13, q2x13, v0); - - vector signed short vscales_07 = vec_unpackh(vscales); - vector signed int vscales_03 = vec_unpackh(vscales_07); - vector signed int vscales_47 = vec_unpackl(vscales_07); - vector signed int vs0 = vec_splat(vscales_03, 0); - vector signed int vs1 = vec_splat(vscales_03, 1); - vector signed int vs2 = vec_splat(vscales_03, 2); - vector signed int vs3 = vec_splat(vscales_03, 3); - vector signed int vs4 = vec_splat(vscales_47, 0); - vector signed int vs5 = vec_splat(vscales_47, 1); - vector signed int vs6 = vec_splat(vscales_47, 2); - vector signed int vs7 = vec_splat(vscales_47, 3); - vscales = vec_sld(vscales, vscales, 8); - - vsumi0 = vec_add(vec_mul(qv0, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv1, vs2), vsumi1); - vsumi2 = vec_add(vec_mul(qv2, vs4), vsumi2); - vsumi3 = vec_add(vec_mul(qv3, vs6), vsumi3); - vsumi4 = vec_add(vec_mul(qv4, vs1), vsumi4); - vsumi5 = vec_add(vec_mul(qv5, vs3), vsumi5); - vsumi6 = vec_add(vec_mul(qv6, vs5), vsumi6); - vsumi7 = vec_add(vec_mul(qv7, vs7), vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - __m256 acc = (__m256)__lasx_xvldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); - const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf); - const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4)); - const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0)); - - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc); - - const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q2bits = __lasx_xvld((const __m256i*)q2, 0); q2 += 32; - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3); - const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3); - const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3); - const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6); - - __m256i p0 = lasx_madd_h_b(q2_0, q8_0); - __m256i p1 = lasx_madd_h_b(q2_1, q8_1); - __m256i p2 = lasx_madd_h_b(q2_2, q8_2); - __m256i p3 = lasx_madd_h_b(q2_3, q8_3); - - p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0); - p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1); - p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2); - p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3); - - p0 = __lasx_xvadd_w(p0, p1); - p2 = __lasx_xvadd_w(p2, p3); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p0, p2)); - } - - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#else - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - int summs = 0; - for (int j = 0; j < 16; ++j) { - summs += y[i].bsums[j] * (sc[j] >> 4); - } - - const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - int isum = 0; - int is = 0; - int d; - for (int k = 0; k < QK_K/128; ++k) { - int shift = 0; - for (int j = 0; j < 4; ++j) { - d = sc[is++] & 0xF; - int isuml = 0; - for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - d = sc[is++] & 0xF; - isuml = 0; - for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); - isum += d * isuml; - shift += 2; - q8 += 32; - } - q2 += 32; - } - sumf += dall * isum - dmin * summs; - } - *s = sumf; -#endif -} - -void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const uint32_t kmask1 = 0x03030303; - const uint32_t kmask2 = 0x0f0f0f0f; - - const block_q3_K * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_FEATURE_SVE) - - uint32_t aux[3]; - uint32_t utmp[4]; - - const int8_t m32 = 32; - const int vector_length = svcntb()*8; - const svuint8_t m3b_sv = svdup_n_u8(0x3); - const svint32_t vzero_sv = svdup_n_s32(0); - - const svuint8_t m0_sv = svdup_n_u8(1); - const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1); - const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2); - const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3); - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q3_sv = x[i].qs; - const uint8_t * GGML_RESTRICT qh_sv = x[i].hmask; - const int8_t * GGML_RESTRICT q8_sv = y[i].qs; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - - for (int j = 0; j < 16; ++j) scale[j] -= m32; - - switch (vector_length) { - case 128: - { - svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv); - svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16); - svuint8_t q3h_sv; - - svint32_t sumi1_1 = svdup_n_s32(0); - svint8_t q3bytes_sv; - - for (int j = 0; j < QK_K/128; ++j) { - - const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; - const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16; - svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); - - q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); - - q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); - - q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); - - - scale += 4; - q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0])); - - q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1])); - - - q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16; - - q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2])); - - q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1); - q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3])); - - if (j == 0) { - qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4); - qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4); - } - - scale += 4; - } - - sum += d * (svaddv_s32(svptrue_b32(), sumi1_1)); - } break; - case 256: - case 512: - { - svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv); - svuint8_t q3h_sv; - - svint32_t sumi1_1 = svdup_n_s32(0); - svint8_t q3bytes_sv; - - for (int j = 0; j < QK_K/128; ++j) { - - const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32; - svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2); - q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - - svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); - sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); - - q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1); - q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); - sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); - - scale += 4; - q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32; - - q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv); - q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1])); - sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1); - - q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1); - q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv)); - - scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3])); - sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1); - - if (j == 0) { - qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4); - } - - scale += 4; - } - - sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1)); - } break; - default: - assert(false && "Unsupported vector length"); - break; - } - } - *s = sum; - -#elif __ARM_NEON - - uint32_t aux[3]; - uint32_t utmp[4]; - - const uint8x16_t m3b = vdupq_n_u8(0x3); - const int32x4_t vzero = vdupq_n_s32(0); - - const uint8x16_t m0 = vdupq_n_u8(1); - const uint8x16_t m1 = vshlq_n_u8(m0, 1); - const uint8x16_t m2 = vshlq_n_u8(m0, 2); - const uint8x16_t m3 = vshlq_n_u8(m0, 3); - const int8_t m32 = 32; - - ggml_int8x16x4_t q3bytes; - - float sum = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); - - ggml_uint8x16x4_t q3h; - - int32_t isum = 0; - - // Set up scales - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= m32; - - for (int j = 0; j < QK_K/128; ++j) { - - const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32; - const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64; - const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64; - - q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); - q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); - q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); - q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; - - scale += 4; - - q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); - q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); - q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); - q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); - - q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); - q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); - q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); - q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; - - scale += 4; - - if (j == 0) { - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); - } - - } - sum += d * isum; - - } - - *s = sum; - -#elif defined __AVX2__ - - const __m256i m3 = _mm256_set1_epi8(3); - const __m256i mone = _mm256_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - uint32_t aux[3]; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); - const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); - const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); - const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; - - // high bit - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); - - // integer accumulator - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; - - // prepare low and high bits - const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); - const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); - const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); - const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); - const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); - ++bit; - - // load Q8 quants - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - // multiply with scales - p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); - - // accumulate - p16_0 = _mm256_add_epi32(p16_0, p16_1); - p16_2 = _mm256_add_epi32(p16_2, p16_3); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); - - } - - // multiply with block scale and accumulate - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m32 = _mm_set1_epi8(32); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - const uint32_t *aux; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - // Set up scales - aux = (const uint32_t *)x[i].scales; - __m128i scales128 = _mm_set_epi32( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = _mm_sub_epi8(scales128, m32); - const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); - const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); - const __m128i scales[2] = { scales_0, scales_1 }; - - // high bit *128*2 from block_q3_K.hmask[QK_K/8] - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); - - // integer accumulator - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] - const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; - - // prepare low and high bits - const int bit = j << 2; - - const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); - const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); - const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); - const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); - - const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); - const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); - const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); - - const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); - const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); - const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); - - const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); - const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); - const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); - - // load Q8 quants from block_q8_K.qs[QK_K] - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, - // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, - // and 2 if the high bit was set) - __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); - __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); - __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); - __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); - __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); - __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); - __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); - __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); - - __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); - - p16_0 = _mm_sub_epi16(p16_0, q8s_0); - p16_1 = _mm_sub_epi16(p16_1, q8s_1); - p16_2 = _mm_sub_epi16(p16_2, q8s_2); - p16_3 = _mm_sub_epi16(p16_3, q8s_3); - p16_4 = _mm_sub_epi16(p16_4, q8s_4); - p16_5 = _mm_sub_epi16(p16_5, q8s_5); - p16_6 = _mm_sub_epi16(p16_6, q8s_6); - p16_7 = _mm_sub_epi16(p16_7, q8s_7); - - // multiply with scales - __m128i shuffle = _mm_set1_epi16(0x0100); - p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); - shuffle = _mm_add_epi16(shuffle, m2); - p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); - shuffle = _mm_add_epi16(shuffle, m2); - p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); - shuffle = _mm_add_epi16(shuffle, m2); - p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); - shuffle = _mm_add_epi16(shuffle, m2); - p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); - shuffle = _mm_add_epi16(shuffle, m2); - p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); - shuffle = _mm_add_epi16(shuffle, m2); - p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); - shuffle = _mm_add_epi16(shuffle, m2); - p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); - - // accumulate - p16_0 = _mm_add_epi32(p16_0, p16_1); - p16_2 = _mm_add_epi32(p16_2, p16_3); - p16_4 = _mm_add_epi32(p16_4, p16_5); - p16_6 = _mm_add_epi32(p16_6, p16_7); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); - - } - - // multiply with block scale and accumulate - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc); - -#elif defined __wasm_simd128__ - int8_t aux8[QK_K]; - float sums[8] = {0}; - uint32_t auxs[4]; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT hm = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - // Process blocks with SIMD - int8_t * a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K; j += 128) { - for (int shift = 0; shift <= 6; shift += 2) { - v128_t v_m = wasm_i8x16_splat(m); - for (int l = 0; l < 32; l += 16) { - v128_t v_q3 = wasm_v128_load(q3 + l); - v128_t v_shift = wasm_i8x16_shr(v_q3, shift); - v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03)); - - v128_t v_hm = wasm_v128_load(hm + l); - v128_t v_mask = wasm_v128_and(v_hm, v_m); - v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0)); - - v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask))); - wasm_v128_store(a + l, v_low2); - } - a += 32; - m <<= 1; - } - q3 += 32; - } - - // Extract scales - memcpy(auxs, x[i].scales, 12); - uint32_t tmp = auxs[2]; - auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - const int8_t * scales = (const int8_t *)auxs; - - // SIMD dot product with register accumulators - v128_t v_acc0 = wasm_i32x4_splat(0); - v128_t v_acc1 = wasm_i32x4_splat(0); - a = aux8; - for (int j = 0; j < QK_K/16; ++j) { - const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32); - - // Process 16 elements per iteration - for (int k = 0; k < 2; ++k) { - const v128_t v_q8 = wasm_i16x8_load8x8(q8); - const v128_t v_a = wasm_i16x8_load8x8(a); - - v128_t v_prod = wasm_i16x8_mul(v_q8, v_a); - v_prod = wasm_i16x8_mul(v_prod, v_scale); - - v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod)); - v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod)); - - q8 += 8; - a += 8; - } - } - - // Accumulate results - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const v128_t v_d = wasm_f32x4_splat(d); - v128_t v_sum = wasm_f32x4_add( - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d), - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d) - ); - - // Accumulate into sums vector - wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum)); - } - - // Horizontal sum - v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4)); - sumf = wasm_f32x4_extract_lane(v_sum, 0) + - wasm_f32x4_extract_lane(v_sum, 1) + - wasm_f32x4_extract_lane(v_sum, 2) + - wasm_f32x4_extract_lane(v_sum, 3); - - *s = sumf; - -#elif defined __riscv_v_intrinsic - - uint32_t aux[3]; - uint32_t utmp[4]; - - const int vector_length = __riscv_vlenb() * 8; - float sumf = 0; - - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; - - - size_t vl = 32; - uint8_t m = 1; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - - int sum_t = 0; - - for (int j = 0; j < QK_K; j += 128) { - - vl = 32; - - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; - - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; - - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q3 += 32; q8 += 128; scale += 8; - - } - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - sumf += d*sum_t; - - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - int8_t * scale = (int8_t *)utmp; - int tmp; - __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v0, (%[s6b])\n\t" - "vmv1r.v v2, v0\n\t" - "vsetivli zero, 2, e64, m1\n\t" - "vmv.v.x v9, %[sh]\n\t"\ - "vslidedown.vi v1, v0, 1\n\t" - "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} - "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} - "vsetivli zero, 4, e32, m1\n\t" - "vid.v v9\n\t" - "vmv.x.s %[tmp], v1\n\t" - "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} - "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} - "vsrl.vv v4, v1, v9\n\t" - "vsrl.vv v2, v0, v8\n\t" - "vand.vx v5, v4, %[kmask1]\n\t" - "vand.vx v3, v2, %[kmask2]\n\t" - "vsll.vi v6, v5, 4\n\t" - "vor.vv v7, v6, v3\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vsub.vx v0, v7, %[c]\n\t" - "vse8.v v0, (%[scale])" - : [tmp] "=&r" (tmp) - : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) - , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - - uint8_t m = 1; - int isum = 0; - for (int j = 0; j < QK_K; j += 128) { - __asm__ __volatile__( - "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" - "vle8.v v8, (%[q3])\n\t" - "vsrl.vi v10, v8, 2\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v8, 6\n\t" - "vand.vi v8, v8, 3\n\t" - "vand.vi v10, v10, 3\n\t" - "vand.vi v12, v12, 3\n\t" - "vle8.v v2, (%[qh])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v8, v8, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v10, v10, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v12, v12, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v14, v14, -4, v0.t\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v10, v9, 1\n\t" - "vslideup.vi v8, v7, 1\n\t" - "vslideup.vi v11, v12, 1\n\t" - "vslideup.vi v13, v14, 1\n\t" - "vslideup.vi v10, v8, 2\n\t" - "vslideup.vi v11, v13, 2\n\t" - "vsetivli zero, 8, e32, m2\n\t"\ - "vle8.v v15, (%[scale])\n\t" - "vsext.vf4 v12, v15\n\t" - "vmul.vv v10, v10, v12\n\t" - "vredsum.vs v0, v10, v0\n\t" - "vmv.x.s %[tmp], v0\n\t" - "add %[isum], %[isum], %[tmp]" - : [tmp] "=&r" (tmp), [m] "+&r" (m), [isum] "+&r" (isum) - : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) - , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q3 += 32; q8 += 128; scale += 8; - } - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * isum; - } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0x3); - const vector signed char lowMask1 = vec_splats((int8_t)0xf); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector signed char v1 = vec_splats((signed char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x20); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - UNUSED(kmask1); - UNUSED(kmask2); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(u0, lowMask1); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = (vector signed char)vec_mergeh((vector signed int)u2, (vector signed int)vec_sr(u2, v2)); - vector signed char u30 = vec_sl(vec_and(u3, lowMask), v4); - vector signed char u31 = vec_and(u3, lowMask2); - - u1 = vec_or(u1, u30); - u2 = vec_or(vec_sr(u0, v4), u31); - - vector signed char vscales = (vector signed char)vec_mergeh((vector signed long long)u1, (vector signed long long)u2); - vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].hmask); - vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].hmask); - - vscales = vec_sub(vscales, off); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q3); - vector signed char qxs1 = (vector signed char)vec_xl(16, q3); - q3 += 32; - - //the low 2 bits - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_and(vec_sr(qxs0, v2), lowMask); - vector signed char qxs02 = vec_and(vec_sr(qxs0, v4), lowMask); - vector signed char qxs03 = vec_and(vec_sr(qxs0, v6), lowMask); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_and(vec_sr(qxs1, v2), lowMask); - vector signed char qxs12 = vec_and(vec_sr(qxs1, v4), lowMask); - vector signed char qxs13 = vec_and(vec_sr(qxs1, v6), lowMask); - - //the 3rd bit - vector signed char qxh00 = vec_sl(vec_andc(v1, qxhs0), v2); - vector signed char qxh01 = vec_sl(vec_andc(v1, vec_sr(qxhs0, (vector unsigned char)v1)), v2); - vector signed char qxh02 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v2)), v2); - vector signed char qxh03 = vec_sl(vec_andc(v1, vec_sr(qxhs0, v3)), v2); - vector signed char qxh10 = vec_sl(vec_andc(v1, qxhs1), v2); - vector signed char qxh11 = vec_sl(vec_andc(v1, vec_sr(qxhs1, (vector unsigned char)v1)), v2); - vector signed char qxh12 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v2)), v2); - vector signed char qxh13 = vec_sl(vec_andc(v1, vec_sr(qxhs1, v3)), v2); - qxhs0 = vec_sr(qxhs0, v4); - qxhs1 = vec_sr(qxhs1, v4); - - vector signed char q3x00 = vec_sub(qxs00, qxh00); - vector signed char q3x01 = vec_sub(qxs01, qxh01); - vector signed char q3x02 = vec_sub(qxs02, qxh02); - vector signed char q3x03 = vec_sub(qxs03, qxh03); - vector signed char q3x10 = vec_sub(qxs10, qxh10); - vector signed char q3x11 = vec_sub(qxs11, qxh11); - vector signed char q3x12 = vec_sub(qxs12, qxh12); - vector signed char q3x13 = vec_sub(qxs13, qxh13); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y02 = vec_xl( 64, q8); - vector signed char q8y12 = vec_xl( 80, q8); - vector signed char q8y03 = vec_xl( 96, q8); - vector signed char q8y13 = vec_xl(112, q8); - q8 += 128; - - vector signed short vscales_h = vec_unpackh(vscales); - vector signed short vs0 = vec_splat(vscales_h, 0); - vector signed short vs1 = vec_splat(vscales_h, 1); - vector signed short vs2 = vec_splat(vscales_h, 2); - vector signed short vs3 = vec_splat(vscales_h, 3); - vector signed short vs4 = vec_splat(vscales_h, 4); - vector signed short vs5 = vec_splat(vscales_h, 5); - vector signed short vs6 = vec_splat(vscales_h, 6); - vector signed short vs7 = vec_splat(vscales_h, 7); - vscales = vec_sld(vscales, vscales, 8); - - vector signed short qv00 = vec_add(vec_mule(q3x00, q8y00), vec_mulo(q3x00, q8y00)); - vector signed short qv01 = vec_add(vec_mule(q3x01, q8y01), vec_mulo(q3x01, q8y01)); - vector signed short qv02 = vec_add(vec_mule(q3x02, q8y02), vec_mulo(q3x02, q8y02)); - vector signed short qv03 = vec_add(vec_mule(q3x03, q8y03), vec_mulo(q3x03, q8y03)); - vector signed short qv10 = vec_add(vec_mule(q3x10, q8y10), vec_mulo(q3x10, q8y10)); - vector signed short qv11 = vec_add(vec_mule(q3x11, q8y11), vec_mulo(q3x11, q8y11)); - vector signed short qv12 = vec_add(vec_mule(q3x12, q8y12), vec_mulo(q3x12, q8y12)); - vector signed short qv13 = vec_add(vec_mule(q3x13, q8y13), vec_mulo(q3x13, q8y13)); - - vsumi0 = vec_msum(qv00, vs0, vsumi0); - vsumi1 = vec_msum(qv01, vs2, vsumi1); - vsumi2 = vec_msum(qv02, vs4, vsumi2); - vsumi3 = vec_msum(qv03, vs6, vsumi3); - vsumi4 = vec_msum(qv10, vs1, vsumi4); - vsumi5 = vec_msum(qv11, vs3, vsumi5); - vsumi6 = vec_msum(qv12, vs5, vsumi6); - vsumi7 = vec_msum(qv13, vs7, vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m128i m32 = __lsx_vreplgr2vr_b(32); - - __m256 acc = (__m256)__lasx_xvldi(0); - - uint32_t aux[3]; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - // Set up scales - memcpy(aux, x[i].scales, 12); - __m128i scales128 = lsx_set_w( - ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), - ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), - (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), - (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); - scales128 = __lsx_vsub_b(scales128, m32); - - const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); - - // high bit - const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0); - - // integer accumulator - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/128; ++j) { - // load low 2 bits - const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32; - - // prepare low and high bits - const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3); - const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3); - const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3); - const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6); - const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2); - const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2); - const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2); - const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2); - const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0); - const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1); - const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2); - const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3); - - // load Q8 quants - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0); - __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1); - __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2); - __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3); - - // multiply with scales - p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); - p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); - p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); - p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); - - // accumulate - p16_0 = __lasx_xvadd_w(p16_0, p16_1); - p16_2 = __lasx_xvadd_w(p16_2, p16_3); - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2)); - } - // multiply with block scale and accumulate - acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); - } - - *s = hsum_float_8(acc); - -#else - // scalar version - // This function is written like this so the compiler can manage to vectorize most of it - // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the - // manually vectorized version above. Every other version I tried would run at least 4 times slower. - // The ideal situation would be if we could just write the code once, and the compiler would - // automatically produce the best possible set of machine instructions, instead of us having to manually - // write vectorized versions for AVX, ARM_NEON, etc. - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - uint32_t auxs[4]; - const int8_t * scales = (const int8_t*)auxs; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT hm = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * GGML_RESTRICT a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; - for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); - a += 32; m <<= 1; - q3 += 32; - } - a = aux8; - - memcpy(auxs, x[i].scales, 12); - uint32_t tmp = auxs[2]; - auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); - auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - for (int j = 0; j < QK_K/16; ++j) { - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; - -#endif - -} - -void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q4_K * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - -#ifdef __ARM_FEATURE_SVE - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, K_SCALE_SIZE); - - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); - - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; - - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const int vector_length = ggml_cpu_get_sve_cnt()*8; - const svuint8_t m4b = svdup_n_u8(0xf); - const svint32_t mzero = svdup_n_s32(0); - svint32_t sumi1 = svdup_n_s32(0); - svint32_t sumi1_1 = svdup_n_s32(0); - svint32_t sumi1_2 = svdup_n_s32(0); - svint32_t sumi2 = svdup_n_s32(0); - svint32_t sumi2_1 = svdup_n_s32(0); - svint32_t sumi2_2 = svdup_n_s32(0); - switch (vector_length) { - case 128: - { - for (int j = 0; j < QK_K/64; ++j) { - svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b)); - svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); - q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b)); - q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); - - q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4)); - q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); - q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4)); - q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16; - sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); - q4 += 32; - } - sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2); - sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2); - sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2))); - } break; - case 256: - case 512: - { - for (int j = 0; j < QK_K/64; ++j) { - const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32; - svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b)); - svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; - sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]); - - q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4)); - q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32; - sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]); - } - sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2))); - } break; - default: - assert(false && "Unsupported vector length"); - break; - } - } - *s = sumf; -#elif defined __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - const int32x4_t mzero = vdupq_n_s32(0); - - ggml_int8x16x2_t q4bytes; - ggml_int8x16x2_t q8bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - - uint32x2_t mins8 = { 0 }; - mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); - mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); - - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; - - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - sumf -= dmin * vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - int32_t sumi1 = 0; - int32_t sumi2 = 0; - - for (int j = 0; j < QK_K/64; ++j) { - const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; - - q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); - q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); - - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - sumi1 += vaddvq_s32(p1) * scales[2*j+0]; - - q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32; - q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); - q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); - - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); - - sumi2 += vaddvq_s32(p2) * scales[2*j+1]; - } - - sumf += d * (sumi1 + sumi2); - - } - - *s = sumf; - -#elif defined __wasm_simd128__ - const uint8_t * scales = (const uint8_t*)&utmp[0]; - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - // Process scales and mins - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - // Sum mins * q8sums - int32_t sumi = 0; - const int16_t * GGML_RESTRICT q8sums = y[i].bsums; - const uint8_t * m = (const uint8_t *)&utmp[2]; - for (int j = 0; j < 16; j += 2) { - sumi += (q8sums[j] + q8sums[j+1]) * m[j/2]; - } - sumf -= dmin * sumi; - - int32_t sumi1 = 0; - int32_t sumi2 = 0; - - for (int j = 0; j < QK_K/64; ++j) { - // Load 64 4-bit weights (32 bytes) - const v128_t q4x0 = wasm_v128_load(q4); - const v128_t q4x1 = wasm_v128_load(q4 + 16); - q4 += 32; - - // Split into low/high nibbles - const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F)); - const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4); - const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F)); - const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4); - - // Load 64 8-bit values (64 bytes) - const v128_t q8x0 = wasm_v128_load(q8); - const v128_t q8x1 = wasm_v128_load(q8 + 16); - const v128_t q8x2 = wasm_v128_load(q8 + 32); - const v128_t q8x3 = wasm_v128_load(q8 + 48); - q8 += 64; - - // Low nibble products - v128_t vacc1 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q4l0), - wasm_i16x8_extend_low_i8x16(q8x0) - ); - vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q4l0), - wasm_i16x8_extend_high_i8x16(q8x0) - )); - vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q4l1), - wasm_i16x8_extend_low_i8x16(q8x1) - )); - vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q4l1), - wasm_i16x8_extend_high_i8x16(q8x1) - )); - - // High nibble products - v128_t vacc2 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q4h0), - wasm_i16x8_extend_low_i8x16(q8x2) - ); - vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q4h0), - wasm_i16x8_extend_high_i8x16(q8x2) - )); - vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q4h1), - wasm_i16x8_extend_low_i8x16(q8x3) - )); - vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q4h1), - wasm_i16x8_extend_high_i8x16(q8x3) - )); - - // Accumulate scaled results - int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) + - wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3); - sumi1 += vacc1_sum * scales[2*j]; - - int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) + - wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3); - sumi2 += vacc2_sum * scales[2*j+1]; - } - - sumf += d * (sumi1 + sumi2); - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - __m256i sumi = _mm256_setzero_si256(); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4l = _mm256_and_si256(q4bits, m4); - const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); - - const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); - p16l = _mm256_madd_epi16(scale_l, p16l); - - const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); - p16h = _mm256_madd_epi16(scale_h, p16h); - const __m256i sumj = _mm256_add_epi32(p16l, p16h); - - sumi = _mm256_add_epi32(sumi, sumj); - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i m2 = _mm_set1_epi8(0x2); - - __m256 acc = _mm256_setzero_ps(); - __m128 acc_m = _mm_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_0 = _mm_and_si128(q4bits, m4); - const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4l_1 = _mm_and_si128(q4bits, m4); - const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); - - const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_0 = _mm_add_epi32(sumi_0, p16l); - const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16l = _mm_maddubs_epi16(q4l_1, q8l_1); - p16l = _mm_madd_epi16(scale_l, p16l); - sumi_1 = _mm_add_epi32(sumi_1, p16l); - - const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_0 = _mm_add_epi32(sumi_0, p16h); - const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - p16h = _mm_maddubs_epi16(q4h_1, q8h_1); - p16h = _mm_madd_epi16(scale_h, p16h); - sumi_1 = _mm_add_epi32(sumi_1, p16h); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); - acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); - - *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - const int vector_length = __riscv_vlenb() * 8; - float sumf = 0; - - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - - size_t vl = 8; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - vl = 32; - - int32_t sum_1 = 0; - int32_t sum_2 = 0; - - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - - q4 += 32; q8 += 64; - - } - - sumf += d*(sum_1 + sum_2); - - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - int tmp, tmp2, sumi; - __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} - "vsetivli zero, 4, e32, m1\n\t" - "vslidedown.vi v2, v1, 2\n\t" - "vmv1r.v v3, v2\n\t" - "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} - "vsetivli zero, 2, e32, m1\n\t" - "vmv.v.i v4, 4\n\t" - "vand.vx v8, v1, %[kmask1]\n\t" - "vslide1up.vx v5, v4, zero\n\t" // {0, 4} - "vsrl.vi v6, v1, 6\n\t" - "vsrl.vv v7, v2, v5\n\t" - "vand.vx v0, v6, %[kmask3]\n\t" - "vand.vx v2, v7, %[kmask2]\n\t" - "vsll.vi v6, v0, 4\n\t" - "li %[t2], 8\n\t" - "addi %[t1], %[utmp], 4\n\t" - "vor.vv v1, v6, v2\n\t" - "vsse32.v v8, (%[utmp]), %[t2]\n\t" - "vsse32.v v1, (%[t1]), %[t2]\n\t" - "vsetivli zero, 8, e16, m1\n\t" - "vle32.v v2, (%[bsums])\n\t" - "vnsrl.wi v0, v2, 0\n\t" - "vnsrl.wi v1, v2, 16\n\t" - "vadd.vv v2, v0, v1\n\t" - "vle8.v v3, (%[mins])\n\t" - "vzext.vf2 v4, v3\n\t" - "vwmul.vv v6, v4, v2\n\t" - "vmv.v.x v0, zero\n\t" - "vsetivli zero, 8, e32, m2\n\t" - "vredsum.vs v0, v6, v0\n\t" - "vmv.x.s %[sumi], v0" - : [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) - : [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) - , [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) - , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - sumf -= dmin * sumi; - - const uint8_t * restrict q4 = x[i].qs; - const int8_t * restrict q8 = y[i].qs; - - sumi = 0; - const uint8_t * scale = scales; - - for (int j = 0; j < QK_K/128; ++j) { - int vl128 = 128, vl64 = 64, vl32 = 32; - __asm__ __volatile__( - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v8, (%[q8])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vle8.v v0, (%[q4])\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vand.vi v0, v0, 0xF\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v20, v4, v10\n\t" - "vwmul.vv v24, v2, v12\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vle8.v v2, (%[scale])\n\t" - "vmv.v.x v0, zero\n\t" - "vzext.vf4 v1, v2\n\t" - "vsetvli zero, %[vl32], e16, m4\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "vwredsum.vs v7, v28, v0\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "vwredsum.vs v5, v20, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v6, v7, 1\n\t" - "vslideup.vi v4, v5, 1\n\t" - "vslideup.vi v4, v6, 2\n\t" - "vmul.vv v8, v4, v1\n\t" - "vredsum.vs v0, v8, v0\n\t" - "vmv.x.s %[tmp], v0\n\t" - "add %[sumi], %[sumi], %[tmp]" - : [tmp] "=&r" (tmp), [sumi] "+&r" (sumi) - : [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32) - , [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - - q4 += 64; q8 += 128; scale += 4; - } - - sumf += d * sumi; - } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed char lowMask1 = vec_splats((int8_t)0x3f); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((uint8_t)2); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(utmp); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = vec_sr(u2, v4); - - vector signed char u30 = u1; - vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); - - u1 = vec_and(u0, lowMask1); - u2 = vec_or(u30, u31); - - vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - - vector signed short vscales = vec_unpackh(utmps); - vector signed short q4xmins = vec_unpackl(utmps); - vector signed short q4xmins0 = vec_mergeh(q4xmins, q4xmins); - vector signed short q4xmins1 = vec_mergel(q4xmins, q4xmins); - - vector signed int prod0 = vec_mule(q4xmins0, q8ysums0); - vector signed int prod1 = vec_mule(q4xmins1, q8ysums1); - vector signed int prod2 = vec_mulo(q4xmins0, q8ysums0); - vector signed int prod3 = vec_mulo(q4xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; j+=2) { - __builtin_prefetch(q4, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); - vector signed char qxs1 = (vector signed char)vec_xl(16, q4); - vector signed char qxs2 = (vector signed char)vec_xl(32, q4); - vector signed char qxs3 = (vector signed char)vec_xl(48, q4); - q4 += 64; - - vector unsigned char q4x00 = (vector unsigned char)vec_and(qxs0, lowMask); - vector unsigned char q4x01 = (vector unsigned char)vec_sr(qxs0, v4); - vector unsigned char q4x10 = (vector unsigned char)vec_and(qxs1, lowMask); - vector unsigned char q4x11 = (vector unsigned char)vec_sr(qxs1, v4); - vector unsigned char q4x20 = (vector unsigned char)vec_and(qxs2, lowMask); - vector unsigned char q4x21 = (vector unsigned char)vec_sr(qxs2, v4); - vector unsigned char q4x30 = (vector unsigned char)vec_and(qxs3, lowMask); - vector unsigned char q4x31 = (vector unsigned char)vec_sr(qxs3, v4); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y01 = vec_xl( 32, q8); - vector signed char q8y11 = vec_xl( 48, q8); - vector signed char q8y20 = vec_xl( 64, q8); - vector signed char q8y30 = vec_xl( 80, q8); - vector signed char q8y21 = vec_xl( 96, q8); - vector signed char q8y31 = vec_xl(112, q8); - q8 += 128; - - vector signed int qv00 = vec_msum(q8y00, q4x00, v0); - vector signed int qv01 = vec_msum(q8y01, q4x01, v0); - vector signed int qv10 = vec_msum(q8y10, q4x10, v0); - vector signed int qv11 = vec_msum(q8y11, q4x11, v0); - vector signed int qv20 = vec_msum(q8y20, q4x20, v0); - vector signed int qv21 = vec_msum(q8y21, q4x21, v0); - vector signed int qv30 = vec_msum(q8y30, q4x30, v0); - vector signed int qv31 = vec_msum(q8y31, q4x31, v0); - - vector signed int vscales_h = vec_unpackh(vscales); - vector signed int vs0 = vec_splat(vscales_h, 0); - vector signed int vs1 = vec_splat(vscales_h, 1); - vector signed int vs2 = vec_splat(vscales_h, 2); - vector signed int vs3 = vec_splat(vscales_h, 3); - vscales = vec_sld(vscales, vscales, 8); - - vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv01, vs1), vsumi1); - vsumi2 = vec_add(vec_mul(qv20, vs2), vsumi2); - vsumi3 = vec_add(vec_mul(qv21, vs3), vsumi3); - - vsumi0 = vec_add(vec_mul(qv10, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv11, vs1), vsumi1); - vsumi2 = vec_add(vec_mul(qv30, vs2), vsumi2); - vsumi3 = vec_add(vec_mul(qv31, vs3), vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - __m256 acc = (__m256)__lasx_xvldi(0); - __m128 acc_m = (__m128)__lsx_vldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); - const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); - - const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); - const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(mins128, q8s); - acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - - const __m256i scales = lasx_insertf128(scales128, scales128); - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0); - const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1); - - const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf); - const __m256i q4h = __lasx_xvsrli_b(q4bits, 4); - - const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16l = lasx_madd_h_b(q4l, q8l); - p16l = lasx_madd_h(scale_l, p16l); - - const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - __m256i p16h = lasx_madd_h_b(q4h, q8h); - p16h = lasx_madd_h(scale_h, p16h); - const __m256i sumj = __lasx_xvadd_w(p16l, p16h); - - sumi = __lasx_xvadd_w(sumi, sumj); - } - - __m256 vd = __lasx_xvreplfr2vr_s(d); - acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); - - } - - acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vpermi_w((__m128i)acc_m, (__m128i)acc_m, 0xee)); - __m128i tmp1 = __lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w((__m128i)acc_m, 1), 0); - acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1); - - - *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; -#elif defined(__VXE__) || defined(__VXE2__) - const uint8x16_t v_lm = vec_splat_u8(0x0F); - const int32x4_t v_z = vec_splat_s32(0); - - uint8x16_t v_x[2]; - int8x16_t v_xl[2]; - int8x16_t v_y[2]; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); - const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); - const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh); - - memcpy(utmp, x[i].scales, 12); - - uint32x4_t v_mins8 = { 0 }; - v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0); - v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1); - - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[0] &= kmask1; - - const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8); - - const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh); - const int32x4_t v_minse = vec_mule(v_ysums, v_minsh); - const int32x4_t v_mins = v_minso + v_minse; - sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]); - - const uint8_t * scales = (const uint8_t *)utmp; - const uint8_t * GGML_RESTRICT x0 = x[i].qs; - const int8_t * GGML_RESTRICT y0 = y[i].qs; - - int32_t sumi1 = 0; - int32_t sumi2 = 0; - - for (int j = 0; j < QK_K/64; ++j) { - v_x[0] = vec_xl(0 , x0); - v_x[1] = vec_xl(16, x0); - x0 += 32; - - v_y[0] = vec_xl(0 , y0); - v_y[1] = vec_xl(16, y0); - y0 += 32; - - v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm); - v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm); - - const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0]; - - v_y[0] = vec_xl(0 , y0); - v_y[1] = vec_xl(16, y0); - y0 += 32; - - v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4); - v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4); - - const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); - sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1]; - } - - sumf += d * (sumi1 + sumi2); - } - - *s = sumf; -#else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * GGML_RESTRICT a = aux8; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - a += 32; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - a += 32; q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q5_K * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - uint32_t utmp[4]; - -#ifdef __ARM_NEON - const uint8x16_t m4b = vdupq_n_u8(0xf); - const uint8x16_t mone = vdupq_n_u8(1); - const uint8x16_t mtwo = vdupq_n_u8(2); - const int32x4_t mzero = vdupq_n_s32(0); - - ggml_int8x16x4_t q5bytes; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); - const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); - const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), - vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); - int32_t sumi_mins = vaddvq_s32(prod); - - const uint8_t * scales = (const uint8_t *)utmp; - - const uint8_t * GGML_RESTRICT q5 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); - - ggml_uint8x16x4_t q5h; - - int32_t sumi = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32; - const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - - q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); - q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); - qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); - qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); - - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); - - sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; - sumi += vaddvq_s32(ggml_vdotq_s32(ggml_vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; - } - - sumf += d * sumi - dmin * sumi_mins; - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m256i mone = _mm256_set1_epi8(1); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q5 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); - - const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); - const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); - const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); - const __m256i scales = MM256_SET_M128I(sc128, sc128); - - const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); - __m256i hmask = mone; - - __m256i sumi = _mm256_setzero_si256(); - - int bit = 0; - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); - const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); - - const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; - - const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); - const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); - const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); - const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); - hmask = _mm256_slli_epi16(hmask, 1); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); - - p16_0 = _mm256_madd_epi16(scale_0, p16_0); - p16_1 = _mm256_madd_epi16(scale_1, p16_1); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - - } - - __m256 vd = _mm256_set1_ps(d); - acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __AVX__ - - const __m128i m4 = _mm_set1_epi8(0xF); - const __m128i mzero = _mm_setzero_si128(); - const __m128i mone = _mm_set1_epi8(1); - const __m128i m2 = _mm_set1_epi8(2); - - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.f; - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const uint8_t * GGML_RESTRICT q5 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i scales = _mm_cvtepu8_epi16(utmps); - const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); - - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); - const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); - const __m128i prod = _mm_madd_epi16(mins, q8s); - const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); - summs += dmin * _mm_extract_epi32(hsum, 0); - - const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); - const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); - __m128i hmask = mone; - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - int bit = 0; - - __m128i shuffle = _mm_set1_epi16(0x0100); - for (int j = 0; j < QK_K/64; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); - shuffle = _mm_add_epi16(shuffle, m2); - - const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; - - __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); - __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); - __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); - __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); - p16_0 = _mm_madd_epi16(scale_0, p16_0); - p16_1 = _mm_madd_epi16(scale_0, p16_1); - - q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); - q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); - q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); - q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); - q5_0 = _mm_add_epi8(q5l_0, q5h_0); - q5_1 = _mm_add_epi8(q5l_1, q5h_1); - hmask = _mm_slli_epi16(hmask, 1); - - q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); - __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); - p16_2 = _mm_madd_epi16(scale_1, p16_2); - p16_3 = _mm_madd_epi16(scale_1, p16_3); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - - } - - __m256 vd = _mm256_set1_ps(d); - __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); - - } - - *s = hsum_float_8(acc) + summs; - -#elif defined __wasm_simd128__ - //const uint8_t * scales = (const uint8_t*)&utmp[0]; - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign - - const uint8_t * GGML_RESTRICT q5 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - // Process scales and mins - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - // Sum mins * q8sums - int32_t sumi_mins = 0; - const int16_t * GGML_RESTRICT q8sums = y[i].bsums; - const uint8_t * m = (const uint8_t *)&utmp[2]; - for (int j = 0; j < 16; j += 2) { - sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2]; - } - sumf -= dmin * sumi_mins; // Correct subtraction - - v128_t qh0 = wasm_v128_load(qh); - v128_t qh1 = wasm_v128_load(qh + 16); - const uint8_t * sc = (const uint8_t *)utmp; - - int32_t sumi = 0; - - for (int j = 0; j < QK_K/64; ++j) { - const int shift = j * 2; - v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift); - v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift); - - v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4); - v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3); - v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4); - v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3); - - v128_t q5_0 = wasm_v128_load(q5); - v128_t q5_1 = wasm_v128_load(q5 + 16); - q5 += 32; - - v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0); - v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0); - v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1); - v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1); - - v128_t q8_0 = wasm_v128_load(q8); - v128_t q8_1 = wasm_v128_load(q8 + 16); - v128_t q8_2 = wasm_v128_load(q8 + 32); - v128_t q8_3 = wasm_v128_load(q8 + 48); - q8 += 64; - - // Process low quants - v128_t pl0 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q5l_0), - wasm_i16x8_extend_low_i8x16(q8_0) - ); - pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q5l_0), - wasm_i16x8_extend_high_i8x16(q8_0) - )); - v128_t pl1 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q5l_1), - wasm_i16x8_extend_low_i8x16(q8_1) - ); - pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q5l_1), - wasm_i16x8_extend_high_i8x16(q8_1) - )); - v128_t sum_low = wasm_i32x4_add(pl0, pl1); - - // Process high quants - v128_t ph0 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q5h_0), - wasm_i16x8_extend_low_i8x16(q8_2) - ); - ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q5h_0), - wasm_i16x8_extend_high_i8x16(q8_2) - )); - v128_t ph1 = wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_low_i8x16(q5h_1), - wasm_i16x8_extend_low_i8x16(q8_3) - ); - ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8( - wasm_i16x8_extend_high_i8x16(q5h_1), - wasm_i16x8_extend_high_i8x16(q8_3) - )); - v128_t sum_high = wasm_i32x4_add(ph0, ph1); - - // Accumulate with scale factors - int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) + - wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3); - int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) + - wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3); - - sumi += sl * sc[2*j] + sh * sc[2*j+1]; - } - - sumf += d * sumi; - } - - *s = sumf; - -#elif defined __riscv_v_intrinsic - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - float sumf = 0; - float sums = 0.0; - - size_t vl; - - for (int i = 0; i < nb; ++i) { - - vl = 8; - - const uint8_t * GGML_RESTRICT q5 = x[i].qs; - const uint8_t * GGML_RESTRICT hm = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - - vint16m1_t q8sums_0 = __riscv_vlse16_v_i16m1(y[i].bsums, 4, vl); - vint16m1_t q8sums_1 = __riscv_vlse16_v_i16m1(y[i].bsums+1, 4, vl); - vint16m1_t q8sums = __riscv_vadd_vv_i16m1(q8sums_0, q8sums_1, vl); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - vuint8mf2_t mins8 = __riscv_vle8_v_u8mf2(mins, vl); - vint16m1_t v_mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, v_mins, vl); - - vint32m1_t sumi = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - - vl = 32; - int32_t aux32 = 0; - int is = 0; - - uint8_t m = 1; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m2_t vqh = __riscv_vle8_v_u8m2(hm, vl); - - for (int j = 0; j < QK_K/64; ++j) { - // load Q5 and Q8 - vuint8m2_t q5_x = __riscv_vle8_v_u8m2(q5, vl); - vint8m2_t q8_y1 = __riscv_vle8_v_i8m2(q8, vl); - vint8m2_t q8_y2 = __riscv_vle8_v_i8m2(q8+32, vl); - - // compute mask for addition - vint8m2_t q5_a = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vand_vx_u8m2(q5_x, 0x0F, vl)); - vuint8m2_t qh_m1 = __riscv_vand_vx_u8m2(vqh, m, vl); - vbool4_t vmask_1 = __riscv_vmsne_vx_u8m2_b4(qh_m1, 0, vl); - vint8m2_t q5_m1 = __riscv_vadd_vx_i8m2_mu(vmask_1, q5_a, q5_a, 16, vl); - m <<= 1; - - vint8m2_t q5_l = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vsrl_vx_u8m2(q5_x, 0x04, vl)); - vuint8m2_t qh_m2 = __riscv_vand_vx_u8m2(vqh, m, vl); - vbool4_t vmask_2 = __riscv_vmsne_vx_u8m2_b4(qh_m2, 0, vl); - vint8m2_t q5_m2 = __riscv_vadd_vx_i8m2_mu(vmask_2, q5_l, q5_l, 16, vl); - m <<= 1; - - vint16m4_t v0 = __riscv_vwmul_vv_i16m4(q5_m1, q8_y1, vl); - vint16m4_t v1 = __riscv_vwmul_vv_i16m4(q5_m2, q8_y2, vl); - - vint32m8_t vs1 = __riscv_vwmul_vx_i32m8(v0, scales[is++], vl); - vint32m8_t vs2 = __riscv_vwmul_vx_i32m8(v1, scales[is++], vl); - - vint32m1_t vacc1 = __riscv_vredsum_vs_i32m8_i32m1(vs1, vzero, vl); - vint32m1_t vacc2 = __riscv_vredsum_vs_i32m8_i32m1(vs2, vacc1, vl); - - aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); - q5 += 32; q8 += 64; - - } - - sums += aux32 * d; - - } - - *s = sumf+sums; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed char lowMask1 = vec_splats((int8_t)0x3f); - const vector signed char lowMask2 = vec_splats((int8_t)0x30); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v1 = vec_splats((unsigned char)0x1); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector float vxmin = vec_splats(GGML_FP16_TO_FP32(x[i].dmin)); - vector float vdmin = vec_mul(vxmin, vyd); - - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(utmp); - - vector signed char u0 = (vector signed char)vec_xl_len(x[i].scales, 8); - vector signed char u1 = vec_and(vec_sr(u0, v2), lowMask2); - vector signed char u2 = (vector signed char)vec_xl_len(x[i].scales + 8, 4); - vector signed char u3 = vec_sr(u2, v4); - - vector signed char u30 = u1; - vector signed char u31 = (vector signed char)vec_mergeh((vector signed int)vec_and(u2, lowMask), (vector signed int)u3); - - u1 = vec_and(u0, lowMask1); - u2 = vec_or(u30, u31); - - vector signed char utmps = (vector signed char)vec_mergeh((vector signed int)u1, (vector signed int)u2); - - vector signed short q8ysums0 = vec_xl( 0, y[i].bsums); - vector signed short q8ysums1 = vec_xl(16, y[i].bsums); - - vector signed short vscales = vec_unpackh(utmps); - - vector signed short q5xmins = vec_unpackl(utmps); - vector signed short q5xmins0 = vec_mergeh(q5xmins, q5xmins); - vector signed short q5xmins1 = vec_mergel(q5xmins, q5xmins); - - vector signed int prod0 = vec_mule(q5xmins0, q8ysums0); - vector signed int prod1 = vec_mule(q5xmins1, q8ysums1); - vector signed int prod2 = vec_mulo(q5xmins0, q8ysums0); - vector signed int prod3 = vec_mulo(q5xmins1, q8ysums1); - - vsumf0 = vec_nmsub(vec_ctf(prod0, 0), vdmin, vsumf0); - vsumf1 = vec_nmsub(vec_ctf(prod1, 0), vdmin, vsumf1); - vsumf2 = vec_nmsub(vec_ctf(prod2, 0), vdmin, vsumf2); - vsumf3 = vec_nmsub(vec_ctf(prod3, 0), vdmin, vsumf3); - - vector signed char qxhs0 = (vector signed char)vec_xl( 0, x[i].qh); - vector signed char qxhs1 = (vector signed char)vec_xl(16, x[i].qh); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * GGML_RESTRICT q5 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; ++j) { - __builtin_prefetch(q5, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q5); - vector signed char qxs1 = (vector signed char)vec_xl(16, q5); - q5 += 32; - - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_sr(qxs0, v4); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_sr(qxs1, v4); - - vector signed char q5h00 = vec_sl(vec_and((vector signed char)v1, qxhs0), v4); - vector signed char q5h01 = vec_sl(vec_and((vector signed char)v2, qxhs0), v3); - vector signed char q5h10 = vec_sl(vec_and((vector signed char)v1, qxhs1), v4); - vector signed char q5h11 = vec_sl(vec_and((vector signed char)v2, qxhs1), v3); - qxhs0 = vec_sr(qxhs0, v2); - qxhs1 = vec_sr(qxhs1, v2); - - vector unsigned char q5x00 = (vector unsigned char)vec_or(q5h00, qxs00); - vector unsigned char q5x01 = (vector unsigned char)vec_or(q5h01, qxs01); - vector unsigned char q5x10 = (vector unsigned char)vec_or(q5h10, qxs10); - vector unsigned char q5x11 = (vector unsigned char)vec_or(q5h11, qxs11); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl(16, q8); - vector signed char q8y01 = vec_xl(32, q8); - vector signed char q8y11 = vec_xl(48, q8); - q8 += 64; - - vector signed int qv00 = vec_msum(q8y00, q5x00, v0); - vector signed int qv01 = vec_msum(q8y01, q5x01, v0); - vector signed int qv10 = vec_msum(q8y10, q5x10, v0); - vector signed int qv11 = vec_msum(q8y11, q5x11, v0); - - vector signed int vscales_h = vec_unpackh(vscales); - vector signed int vs0 = vec_splat(vscales_h, 0); - vector signed int vs1 = vec_splat(vscales_h, 1); - vscales = vec_sld(vscales, vscales, 12); - - vsumi0 = vec_add(vec_mul(qv00, vs0), vsumi0); - vsumi1 = vec_add(vec_mul(qv10, vs0), vsumi1); - vsumi2 = vec_add(vec_mul(qv01, vs1), vsumi2); - vsumi3 = vec_add(vec_mul(qv11, vs1), vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - __m256 acc = (__m256)__lasx_xvldi(0); - __m128 acc_m = (__m128)__lsx_vldi(0); - - for (int i = 0; i < nb; ++i) { - - const uint8_t * GGML_RESTRICT q5 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]); - const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128); - const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0); - - const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0); - const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1)); - const __m128i prod = lsx_madd_h(mins128, q8s); - acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m); - - const __m256i scales = lasx_insertf128(scales128, scales128); - - const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0); - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/64; ++j) { - - const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0); - const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1); - - const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32; - - const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf); - const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4); - const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef); - const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef); - const __m256i q5_0 = __lasx_xvor_v(q5l_0, q5h_0); - const __m256i q5_1 = __lasx_xvor_v(q5l_1, q5h_1); - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0); - __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1); - - p16_0 = lasx_madd_h(scale_0, p16_0); - p16_1 = lasx_madd_h(scale_1, p16_1); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); - - } - - __m256 vd = __lasx_xvreplfr2vr_s(d); - acc = __lasx_xvfmadd_s(vd, __lasx_xvffint_s_w(sumi), acc); - - } - - acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8)); - acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4)); - - *s = hsum_float_8(acc) + ((v4f32)acc_m)[0]; -#elif defined(__VXE__) || defined(__VXE2__) - const uint8x16_t v_lm = vec_splat_u8(0x0F); - const uint8x16_t v_1m = vec_splat_u8(0x01); - const uint8x16_t v_2m = vec_splat_u8(0x02); - - const int32x4_t v_z = vec_splat_s32(0); - - const uchar8x16_t v_minsm = { - 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, - 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF - }; - - int8x16_t q5b[4]; - uint8x16_t q5h[4]; - - uint8x16_t v_xl[2]; - uint8x16_t v_xh[2]; - int8x16_t v_y[4]; - - float sumf = 0; - - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); - - const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); - const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); - const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh); - - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp); - const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm); - const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8); - - const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); - const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); - const int32x4_t v_mins = vec_add(v_minsho, v_minshe); - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; - - const uint8_t * scales = (const uint8_t *)utmp; - const uint8_t * GGML_RESTRICT x0l = x[i].qs; - const uint8_t * GGML_RESTRICT x0h = x[i].qh; - const int8_t * GGML_RESTRICT y0 = y[i].qs; - - v_xh[0] = vec_xl(0 , x0h); - v_xh[1] = vec_xl(16, x0h); - - int32_t sumi = 0; - for (int j = 0; j < QK_K/64; ++j) { - v_xl[0] = vec_xl(0 , x0l); - v_xl[1] = vec_xl(16, x0l); - x0l += 32; - - v_y[0] = vec_xl(0 , y0); - v_y[1] = vec_xl(16, y0); - v_y[2] = vec_xl(32, y0); - v_y[3] = vec_xl(48, y0); - y0 += 64; - - q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4); - q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4); - q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3); - q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3); - v_xh[0] = vec_sr(v_xh[0], 2); - v_xh[1] = vec_sr(v_xh[1], 2); - - q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]); - q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]); - q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]); - q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]); - - int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]); - int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]); - - sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++; - sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++; - } - - sumf += d * sumi - dmin * mins; - } - - *s = sumf; -#else - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const uint8_t * GGML_RESTRICT hm = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * GGML_RESTRICT a = aux8; - uint8_t m = 1; - for (int j = 0; j < QK_K/64; ++j) { - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); - for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); - a += 32; m <<= 1; - q4 += 32; - } - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; - - int sumi = 0; - for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/32; ++j) { - int32_t scale = scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; - sumf -= dmin * sumi; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q6_K * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#ifdef __ARM_FEATURE_SVE - const int vector_length = ggml_cpu_get_sve_cnt()*8; - float sum = 0; - svuint8_t m4b = svdup_n_u8(0xf); - svint32_t vzero = svdup_n_s32(0); - svuint8_t mone = svdup_n_u8(0x30); - svint8_t q6bytes_1, q6bytes_2, q6bytes_3, q6bytes_4; - svuint8_t q6h_1, q6h_2, q6h_3, q6h_4; - - for (int i = 0; i < nb; ++i) { - const float d_all = GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const int8_t * GGML_RESTRICT scale = x[i].scales; - - const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8); - const svint16_t q8sums_1 = svld1_s16(pg16_8, y[i].bsums); - const svint16_t q8sums_2 = svld1_s16(pg16_8, y[i].bsums + 8); - const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale)); - const svint16_t q6scales_2 = svunpklo_s16(svld1_s8(svptrue_pat_b8(SV_VL8), scale + 8)); - const svint64_t prod = svdup_n_s64(0); - int32_t isum_mins = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(prod, q8sums_1, q6scales_1), - svdot_s64(prod, q8sums_2, q6scales_2))); - int32_t isum = 0; - - switch (vector_length) { - case 128: - { - const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4); - const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16); - svint32_t isum_tmp = svdup_n_s32(0); - for (int j = 0; j < QK_K/128; ++j) { - svuint8_t qhbits_1 = svld1_u8(pg8_16, qh); - svuint8_t qhbits_2 = svld1_u8(pg8_16, qh+16); - qh += 32; - svuint8_t q6bits_1 = svld1_u8(pg8_16, q6); - svuint8_t q6bits_2 = svld1_u8(pg8_16, q6+16); - svuint8_t q6bits_3 = svld1_u8(pg8_16, q6+32); - svuint8_t q6bits_4 = svld1_u8(pg8_16, q6+48); - q6 += 64; - svint8_t q8bytes_1 = svld1_s8(pg8_16, q8); - svint8_t q8bytes_2 = svld1_s8(pg8_16, q8+16); - svint8_t q8bytes_3 = svld1_s8(pg8_16, q8+32); - svint8_t q8bytes_4 = svld1_s8(pg8_16, q8+48); - q8 += 64; - - q6h_1 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 4)); - q6h_2 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 4)); - q6h_3 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_1, 2)); - q6h_4 = svand_u8_x(pg16_8, mone, svlsl_n_u8_x(pg16_8, qhbits_2, 2)); - q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_1, m4b), q6h_1)); - q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_2, m4b), q6h_2)); - q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_3, m4b), q6h_3)); - q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svand_u8_x(pg8_16, q6bits_4, m4b), q6h_4)); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]); - - scale += 4; - q8bytes_1 = svld1_s8(pg8_16, q8); - q8bytes_2 = svld1_s8(pg8_16, q8+16); - q8bytes_3 = svld1_s8(pg8_16, q8+32); - q8bytes_4 = svld1_s8(pg8_16, q8+48); - q8 += 64; - - q6h_1 = svand_u8_x(pg16_8, mone, qhbits_1); - q6h_2 = svand_u8_x(pg16_8, mone, qhbits_2); - q6h_3 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_1, 2)); - q6h_4 = svand_u8_x(pg16_8, mone, svlsr_n_u8_x(pg16_8, qhbits_2, 2)); - q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_1, 4), q6h_1)); - q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_2, 4), q6h_2)); - q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_3, 4), q6h_3)); - q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_16, svlsr_n_u8_x(pg8_16, q6bits_4, 4), q6h_4)); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale[0]); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale[1]); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale[2]); - isum_tmp = svmla_n_s32_x(pg32_4, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale[3]); - scale += 4; - } - isum += svaddv_s32(pg32_4, isum_tmp); - sum += d_all * y[i].d * (isum - 32 * isum_mins); - } - break; - case 256: - case 512: - { - const svbool_t pg8_2 = svptrue_pat_b8(SV_VL2); - const svbool_t pg32_8 = svptrue_pat_b32(SV_VL8); - const svbool_t pg8_32 = svptrue_pat_b8(SV_VL32); - svint32_t isum_tmp = svdup_n_s32(0); - for (int j = 0; j < QK_K/128; j++) { - svuint8_t qhbits_1 = svld1_u8(pg8_32, qh); - qh += 32; - svuint8_t q6bits_1 = svld1_u8(pg8_32, q6); - svuint8_t q6bits_2 = svld1_u8(pg8_32, q6+32); - q6 += 64; - svint8_t q8bytes_1 = svld1_s8(pg8_32, q8); - svint8_t q8bytes_2 = svld1_s8(pg8_32, q8+32); - svint8_t q8bytes_3 = svld1_s8(pg8_32, q8+64); - svint8_t q8bytes_4 = svld1_s8(pg8_32, q8+96); - q8 += 128; - q6h_1 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 4)); - q6h_2 = svand_u8_x(pg8_32, mone, svlsl_n_u8_x(pg8_32, qhbits_1, 2)); - q6h_3 = svand_u8_x(pg8_32, mone, qhbits_1); - q6h_4 = svand_u8_x(pg8_32, mone, svlsr_n_u8_x(pg8_32, qhbits_1, 2)); - q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_1, m4b), q6h_1)); - q6bytes_2 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svand_u8_x(pg8_32, q6bits_2, m4b), q6h_2)); - q6bytes_3 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_1, 4), q6h_3)); - q6bytes_4 = svreinterpret_s8_u8(svorr_u8_x(pg8_32, svlsr_n_u8_x(pg8_32, q6bits_2, 4), q6h_4)); - - svint8_t scale_lane_1_tmp = svld1_s8(pg8_2, scale); - scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp); - scale_lane_1_tmp= svzip1_s8(scale_lane_1_tmp, scale_lane_1_tmp); - svint8_t scale_lane_2_tmp = svld1_s8(pg8_2, scale+2); - scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp); - scale_lane_2_tmp = svzip1_s8(scale_lane_2_tmp, scale_lane_2_tmp); - svint8_t scale_lane_3_tmp = svld1_s8(pg8_2, scale+4); - scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp); - scale_lane_3_tmp = svzip1_s8(scale_lane_3_tmp, scale_lane_3_tmp); - svint8_t scale_lane_4_tmp = svld1_s8(pg8_2, scale+6); - scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp); - scale_lane_4_tmp = svzip1_s8(scale_lane_4_tmp, scale_lane_4_tmp); - svint32_t scale_lane_1 = svunpklo_s32(svunpklo_s16(scale_lane_1_tmp)); - svint32_t scale_lane_2 = svunpklo_s32(svunpklo_s16(scale_lane_2_tmp)); - svint32_t scale_lane_3 = svunpklo_s32(svunpklo_s16(scale_lane_3_tmp)); - svint32_t scale_lane_4 = svunpklo_s32(svunpklo_s16(scale_lane_4_tmp)); - - isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_1, q8bytes_1), scale_lane_1); - isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_2, q8bytes_2), scale_lane_2); - isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_3, q8bytes_3), scale_lane_3); - isum_tmp = svmla_s32_x(pg32_8, isum_tmp, svdot_s32(vzero, q6bytes_4, q8bytes_4), scale_lane_4); - scale += 8; - } - isum += svaddv_s32(pg32_8, isum_tmp); - sum += d_all * y[i].d * (isum - 32 * isum_mins); - } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - } - - *s = sum; - -#elif __ARM_NEON - float sum = 0; - - const uint8x16_t m4b = vdupq_n_u8(0xF); - const int32x4_t vzero = vdupq_n_s32(0); - //const int8x16_t m32s = vdupq_n_s8(32); - - const uint8x16_t mone = vdupq_n_u8(3); - - ggml_int8x16x4_t q6bytes; - ggml_uint8x16x4_t q6h; - - for (int i = 0; i < nb; ++i) { - - const float d_all = GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const int8_t * GGML_RESTRICT scale = x[i].scales; - - const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums); - const int8x16_t scales = vld1q_s8(scale); - const ggml_int16x8x2_t q6scales = {{vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}}; - - const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), - vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), - vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), - vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); - int32_t isum_mins = vaddvq_s32(prod); - - int32_t isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32; - ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64; - ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); - uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 2); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - - scale += 4; - - q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64; - - shifted = vshrq_n_u8(qhbits.val[0], 4); - q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 4); - q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[0], 6); - q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - shifted = vshrq_n_u8(qhbits.val[1], 6); - q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); - - //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); - //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); - //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); - //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); - q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); - q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); - q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); - q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); - - isum += vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + - vaddvq_s32(ggml_vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; - scale += 4; - } - //sum += isum * d_all * y[i].d; - sum += d_all * y[i].d * (isum - 32 * isum_mins); - - } - *s = sum; - -#elif defined __AVX2__ - - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q4 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - - __m256i sumi = _mm256_setzero_si256(); - - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; - const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); - - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); - - const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); - __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); - __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); - __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); - - p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); - p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); - p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); - - } - - acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __AVX__ - - const __m128i m3 = _mm_set1_epi8(3); - const __m128i m15 = _mm_set1_epi8(15); - - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q4 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - // handle the q6_k -32 offset separately using bsums - const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)y[i].bsums); - const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)y[i].bsums + 1); - const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); - const __m128i scales_16_0 = _mm_cvtepi8_epi16(scales); - const __m128i scales_16_1 = _mm_cvtepi8_epi16(_mm_bsrli_si128(scales, 8)); - const __m128i q8sclsub_0 = _mm_slli_epi32(_mm_madd_epi16(q8sums_0, scales_16_0), 5); - const __m128i q8sclsub_1 = _mm_slli_epi32(_mm_madd_epi16(q8sums_1, scales_16_1), 5); - - __m128i sumi_0 = _mm_setzero_si128(); - __m128i sumi_1 = _mm_setzero_si128(); - - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; - - const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); - const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); - const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(12)), 2); - const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(12)), 2); - const __m128i q4h_4 = _mm_and_si128(q4bitsH_0, _mm_set1_epi8(48)); - const __m128i q4h_5 = _mm_and_si128(q4bitsH_1, _mm_set1_epi8(48)); - const __m128i q4h_6 = _mm_srli_epi16(_mm_and_si128(q4bitsH_0, _mm_set1_epi8(-64)), 2); - const __m128i q4h_7 = _mm_srli_epi16(_mm_and_si128(q4bitsH_1, _mm_set1_epi8(-64)), 2); - - const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; - - const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m15), q4h_0); - const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m15), q4h_1); - const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m15), q4h_2); - const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m15), q4h_3); - const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m15), q4h_4); - const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m15), q4h_5); - const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m15), q4h_6); - const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m15), q4h_7); - - const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; - - __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); - __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); - __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); - __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); - __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); - __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); - __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); - __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - - p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); - p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_0, 8)), p16_1); - p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); - p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_1, 8)), p16_3); - p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); - p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_2, 8)), p16_5); - p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); - p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_bsrli_si128(scale_3, 8)), p16_7); - - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); - sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); - sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); - - } - - sumi_0 = _mm_sub_epi32(sumi_0, q8sclsub_0); - sumi_1 = _mm_sub_epi32(sumi_1, q8sclsub_1); - const __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); - acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi)), acc); - } - - *s = hsum_float_8(acc); - -#elif defined __wasm_simd128__ - int8_t aux8[QK_K] __attribute__((aligned(16))); - int32_t aux32[8] __attribute__((aligned(16))) = {0}; - float sums[8] __attribute__((aligned(16))) = {0}; - - for (int i = 0; i < nb; ++i) { - // Unpack 6-bit quantized data into aux8 (unchanged) - const uint8_t * GGML_RESTRICT q4 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - int8_t * a = aux8; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - } - a += 128; - q4 += 64; - qh += 32; - } - - const int8_t * GGML_RESTRICT a_ptr = aux8; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - v128_t acc0 = wasm_i32x4_splat(0); - v128_t acc1 = wasm_i32x4_splat(0); - - for (int j = 0; j < QK_K/16; ++j) { - const int scale = x[i].scales[j]; - const v128_t vscale = wasm_i32x4_splat(scale); - - // Load 16 elements from a and q8 - const v128_t a_vec = wasm_v128_load(a_ptr); - const v128_t q8_vec = wasm_v128_load(q8); - - // Process low 8 elements - v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec); - v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec); - v128_t prod_low = wasm_i16x8_mul(a_low, q8_low); - v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low); - v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low); - - // Process high 8 elements - v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec); - v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec); - v128_t prod_high = wasm_i16x8_mul(a_high, q8_high); - v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high); - v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high); - - // Scale and accumulate - prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale); - prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale); - prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale); - prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale); - - acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo)); - acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi)); - - a_ptr += 16; - q8 += 16; - } - - // Store accumulated results - wasm_v128_store(&aux32[0], acc0); - wasm_v128_store(&aux32[4], acc1); - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) { - sums[l] += d * aux32[l]; - } - } - - // Sum final results - float sumf = 0; - for (int l = 0; l < 8; ++l) { - sumf += sums[l]; - } - *s = sumf; - -#elif defined __riscv_v_intrinsic - - const int vector_length = __riscv_vlenb() * 8; - float sumf = 0; - - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const int8_t * GGML_RESTRICT scale = x[i].scales; - - size_t vl; - - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - - int sum_t = 0; - int is = 0; - - for (int j = 0; j < QK_K/128; ++j) { - - vl = 32; - - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - - vl = 16; - - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - - q6 += 64; qh += 32; q8 += 128; is=8; - - } - - sumf += d * sum_t; - - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int sum_t = 0; - int t0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "vsetvli zero, %[vl32], e8, m2\n\t" - "vle8.v v4, (%[qh])\n\t" - "vsll.vi v0, v4, 4\n\t" - "vsll.vi v2, v4, 2\n\t" - "vsrl.vi v6, v4, 2\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vle8.v v8, (%[q6])\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vand.vi v8, v8, 0xF\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vand.vx v0, v0, %[mask]\n\t" - "vor.vv v8, v8, v0\n\t" - "vle8.v v0, (%[q8])\n\t" - "vsub.vx v8, v8, %[vl32]\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vslideup.vi v10, v9, 1\n\t" - "vslideup.vi v8, v7, 1\n\t" - "vslideup.vi v11, v12, 1\n\t" - "vslideup.vi v13, v14, 1\n\t" - "vslideup.vi v10, v8, 2\n\t" - "vslideup.vi v11, v13, 2\n\t" - "vsetivli zero, 8, e32, m2\n\t" - "vle8.v v2, (%[scale])\n\t" - "vsext.vf4 v4, v2\n\t" - "vmul.vv v2, v4, v10\n\t" - "vredsum.vs v0, v2, v0\n\t" - "vmv.x.s %[t0], v0\n\t" - "add %[sumi], %[sumi], %[t0]" - : [sumi] "+&r" (sum_t), [t0] "=&r" (t0) - : [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - , [mask] "r" (0x30) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q6 += 64; qh += 32; q8 += 128; scale += 8; - } - - sumf += d * sum_t; - - } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - - *s = sumf; - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v2 = vec_splats((unsigned char)0x2); - const vector unsigned char v3 = vec_splats((unsigned char)0x3); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector unsigned char v6 = vec_splats((unsigned char)0x6); - const vector signed char off = vec_splats((signed char)0x20); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - vector signed int vsumi4 = v0; - vector signed int vsumi5 = v0; - vector signed int vsumi6 = v0; - vector signed int vsumi7 = v0; - - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT qs = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/128; ++j) { - __builtin_prefetch(q6, 0, 0); - __builtin_prefetch(qh, 0, 0); - __builtin_prefetch(q8, 0, 0); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q6); - vector signed char qxs1 = (vector signed char)vec_xl(16, q6); - vector signed char qxs2 = (vector signed char)vec_xl(32, q6); - vector signed char qxs3 = (vector signed char)vec_xl(48, q6); - q6 += 64; - - vector signed char qxs00 = vec_and(qxs0, lowMask); - vector signed char qxs01 = vec_sr(qxs0, v4); - vector signed char qxs10 = vec_and(qxs1, lowMask); - vector signed char qxs11 = vec_sr(qxs1, v4); - vector signed char qxs20 = vec_and(qxs2, lowMask); - vector signed char qxs21 = vec_sr(qxs2, v4); - vector signed char qxs30 = vec_and(qxs3, lowMask); - vector signed char qxs31 = vec_sr(qxs3, v4); - - vector signed char qxhs0 = (vector signed char)vec_xl( 0, qh); - vector signed char qxhs1 = (vector signed char)vec_xl(16, qh); - qh += 32; - - vector signed char qxh00 = vec_sl(vec_and((vector signed char)v3, qxhs0), v4); - vector signed char qxh01 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v4)), v4); - vector signed char qxh10 = vec_sl(vec_and((vector signed char)v3, qxhs1), v4); - vector signed char qxh11 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v4)), v4); - vector signed char qxh20 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v2)), v4); - vector signed char qxh21 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs0, v6)), v4); - vector signed char qxh30 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v2)), v4); - vector signed char qxh31 = vec_sl(vec_and((vector signed char)v3, vec_sr(qxhs1, v6)), v4); - - vector signed char q6x00 = vec_sub(vec_or(qxh00, qxs00), off); - vector signed char q6x01 = vec_sub(vec_or(qxh01, qxs01), off); - vector signed char q6x10 = vec_sub(vec_or(qxh10, qxs10), off); - vector signed char q6x11 = vec_sub(vec_or(qxh11, qxs11), off); - vector signed char q6x20 = vec_sub(vec_or(qxh20, qxs20), off); - vector signed char q6x21 = vec_sub(vec_or(qxh21, qxs21), off); - vector signed char q6x30 = vec_sub(vec_or(qxh30, qxs30), off); - vector signed char q6x31 = vec_sub(vec_or(qxh31, qxs31), off); - - vector signed char q8y00 = vec_xl( 0, q8); - vector signed char q8y10 = vec_xl( 16, q8); - vector signed char q8y20 = vec_xl( 32, q8); - vector signed char q8y30 = vec_xl( 48, q8); - vector signed char q8y01 = vec_xl( 64, q8); - vector signed char q8y11 = vec_xl( 80, q8); - vector signed char q8y21 = vec_xl( 96, q8); - vector signed char q8y31 = vec_xl(112, q8); - q8 += 128; - - vector signed short qv00 = vec_add(vec_mule(q6x00, q8y00), vec_mulo(q6x00, q8y00)); - vector signed short qv10 = vec_add(vec_mule(q6x10, q8y10), vec_mulo(q6x10, q8y10)); - vector signed short qv20 = vec_add(vec_mule(q6x20, q8y20), vec_mulo(q6x20, q8y20)); - vector signed short qv30 = vec_add(vec_mule(q6x30, q8y30), vec_mulo(q6x30, q8y30)); - vector signed short qv01 = vec_add(vec_mule(q6x01, q8y01), vec_mulo(q6x01, q8y01)); - vector signed short qv11 = vec_add(vec_mule(q6x11, q8y11), vec_mulo(q6x11, q8y11)); - vector signed short qv21 = vec_add(vec_mule(q6x21, q8y21), vec_mulo(q6x21, q8y21)); - vector signed short qv31 = vec_add(vec_mule(q6x31, q8y31), vec_mulo(q6x31, q8y31)); - - vector signed short vscales = vec_unpackh(vec_xl_len(qs, 8)); - qs += 8; - - vector signed short vs0 = vec_splat(vscales, 0); - vector signed short vs1 = vec_splat(vscales, 1); - vector signed short vs2 = vec_splat(vscales, 2); - vector signed short vs3 = vec_splat(vscales, 3); - vector signed short vs4 = vec_splat(vscales, 4); - vector signed short vs5 = vec_splat(vscales, 5); - vector signed short vs6 = vec_splat(vscales, 6); - vector signed short vs7 = vec_splat(vscales, 7); - - vsumi0 = vec_msum(qv00, vs0, vsumi0); - vsumi1 = vec_msum(qv01, vs4, vsumi1); - vsumi2 = vec_msum(qv10, vs1, vsumi2); - vsumi3 = vec_msum(qv11, vs5, vsumi3); - vsumi4 = vec_msum(qv20, vs2, vsumi4); - vsumi5 = vec_msum(qv21, vs6, vsumi5); - vsumi6 = vec_msum(qv30, vs3, vsumi6); - vsumi7 = vec_msum(qv31, vs7, vsumi7); - } - - vsumi0 = vec_add(vsumi0, vsumi4); - vsumi1 = vec_add(vsumi1, vsumi5); - vsumi2 = vec_add(vsumi2, vsumi6); - vsumi3 = vec_add(vsumi3, vsumi7); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined __loongarch_asx - - const __m256i m32s = __lasx_xvreplgr2vr_b(32); - - __m256 acc = (__m256)__lasx_xvldi(0); - - for (int i = 0; i < nb; ++i) { - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT q4 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0); - const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; - const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask)); - - __m256i sumi = __lasx_xvldi(0); - - for (int j = 0; j < QK_K/128; ++j) { - - const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32; - const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32; - - const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4); - const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2); - const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4); - const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2); - - const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0); - const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1); - const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2); - const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3); - - const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0); - __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1); - __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2); - __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3); - - p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0); - p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1); - p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2); - p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1)); - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3)); - } - - acc = __lasx_xvfmadd_s((__m256)__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc); - } - - *s = hsum_float_8(acc); -#elif defined(__VXE__) || defined(__VXE2__) - float sum = 0; - - // Lower 4-bit and upper 2-bit masks - const uint8x16_t v_lm = vec_splat_u8(0x0F); - const uint8x16_t v_um = vec_splat_u8(0x03); - - const int32x4_t v_z = vec_splat_s32(0); - - int8x16_t q6b[4]; - uint8x16_t q6h[4]; - - uint8x16_t v_xl[4]; - uint8x16_t v_xh[2]; - int8x16_t v_y[4]; - - for (int i = 0; i < nb; ++i) { - const float d_all = GGML_FP16_TO_FP32(x[i].d); - - const uint8_t * GGML_RESTRICT x0l = x[i].ql; - const uint8_t * GGML_RESTRICT x0h = x[i].qh; - const int8_t * GGML_RESTRICT y0 = y[i].qs; - - const int8_t * GGML_RESTRICT scale = x[i].scales; - - const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums); - const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums); - - const int8x16_t v_scale = vec_xl(0, scale); - const int16x8_t v_scalel = vec_unpackh(v_scale); - const int16x8_t v_scaleh = vec_unpackl(v_scale); - - const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel); - const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel); - const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh); - const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); - const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; - - const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; - - int32_t isum = 0; - for (int j = 0; j < QK_K/128; ++j) { - // Load model upper 2 bits - v_xh[0] = vec_xl(0 , x0h); - v_xh[1] = vec_xl(16, x0h); - x0h += 32; - - // Load model lower 4 bits - v_xl[0] = vec_xl(0 , x0l); - v_xl[1] = vec_xl(16, x0l); - v_xl[2] = vec_xl(32, x0l); - v_xl[3] = vec_xl(48, x0l); - x0l += 64; - - // Load activation quants - v_y[0] = vec_xl(0 , y0); - v_y[1] = vec_xl(16, y0); - v_y[2] = vec_xl(32, y0); - v_y[3] = vec_xl(48, y0); - y0 += 64; - - q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4); - q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4); - uint8x16_t shifted = vec_sr(v_xh[0], 2); - q6h[2] = vec_sl(vec_and(v_um, shifted), 4); - shifted = vec_sr(v_xh[1], 2); - q6h[3] = vec_sl(vec_and(v_um, shifted), 4); - - q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0])); - q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1])); - q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2])); - q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3])); - - int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]); - int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]); - int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); - int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; - - scale += 4; - - - // Load activation quants - v_y[0] = vec_xl(0 , y0); - v_y[1] = vec_xl(16, y0); - v_y[2] = vec_xl(32, y0); - v_y[3] = vec_xl(48, y0); - y0 += 64; - - shifted = vec_sr(v_xh[0], 4); - q6h[0] = vec_sl(vec_and(v_um, shifted), 4); - shifted = vec_sr(v_xh[1], 4); - q6h[1] = vec_sl(vec_and(v_um, shifted), 4); - shifted = vec_sr(v_xh[0], 6); - q6h[2] = vec_sl(vec_and(v_um, shifted), 4); - shifted = vec_sr(v_xh[1], 6); - q6h[3] = vec_sl(vec_and(v_um, shifted), 4); - - q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0])); - q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1])); - q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2])); - q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3])); - - summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]); - summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]); - summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); - summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); - - isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + - (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + - (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + - (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; - - scale += 4; - } - - sum += d_all * y[i].d * (isum - 32 * mins); - } - - *s = sum; -#else - - int8_t aux8[QK_K]; - int16_t aux16[8]; - float sums [8]; - int32_t aux32[8]; - memset(sums, 0, 8*sizeof(float)); - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const uint8_t * GGML_RESTRICT q4 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - memset(aux32, 0, 8*sizeof(int32_t)); - int8_t * GGML_RESTRICT a = aux8; - for (int j = 0; j < QK_K; j += 128) { - for (int l = 0; l < 32; ++l) { - a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; - a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; - a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; - a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; - } - a += 128; - q4 += 64; - qh += 32; - } - a = aux8; - int is = 0; - for (int j = 0; j < QK_K/16; ++j) { - int scale = x[i].scales[is++]; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; - for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; - q8 += 8; a += 8; - } - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; - } - for (int l = 0; l < 8; ++l) sumf += sums[l]; - *s = sumf; -#endif -} - -#if defined (__AVX__) || defined (__AVX2__) || defined (__ARM_NEON) || defined (__POWER9_VECTOR__) || defined(__loongarch_asx) -static const int8_t keven_signs_q2xs[1024] = { - 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, - 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, - 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, - 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, - 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, - 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, - 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, - 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, - 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, - 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, - 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, - 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, - 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, - 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, - 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, - 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, - 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, - 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, - 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, - 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, - 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, - 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, - 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, - 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, - 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, - 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, - 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, - 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, - 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, - 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, - 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, - 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -}; -#endif - -void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_xxs * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - ggml_int8x16x4_t q2u; - ggml_int8x16x4_t q2s; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - float sumf1 = 0, sumf2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 1]))); - q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 3]))); - q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_s8((const void *)(iq2xxs_grid + aux8[ 9]))); - q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xxs_grid + aux8[10])), vld1_s8((const void *)(iq2xxs_grid + aux8[11]))); - q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); - q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); - q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); - q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); - q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); - q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); - q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); - q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); - sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); - sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); - } - sumf += d*(sumf1 + sumf2); - } - *s = 0.25f * sumf; - -#elif defined(__AVX2__) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], - signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - const __m128i q2_1_0 = _mm_set_epi64x(iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2xxs_grid[aux8[3]], iq2xxs_grid[aux8[2]]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]]); - const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); - const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127]); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - const vector int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - memcpy(aux32, q2, 4*sizeof(uint32_t)); - q2 += 8; - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1])}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3])}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9])}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11])}; - - vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127))}; - vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127))}; - vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127))}; - vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127))}; - - vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); - vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); - vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); - vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = aux32[1] >> 28; - const uint16_t ls1 = aux32[3] >> 28; - - vector signed short vscales01 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales23 = vec_splats((int16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[4]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; - - const __m256i q2_1 = lasx_set_d(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]); - const __m256i q2_2 = lasx_set_d(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]); - const __m256i s2_1 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i s2_2 = lasx_set_d(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127], - signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]); - const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); - const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = aux32[1] >> 28; - const uint16_t ls2 = aux32[3] >> 28; - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.125f * hsum_float_8(accumf); -//#elif defined(__VXE__) || defined(__VXE2__) -// const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; -// -// uint32_t aux32[4]; -// const uint8_t * aux8 = (const uint8_t *)aux32; -// -// float sumf = 0; -// -// for (int i = 0; i < nb; ++i) { -// const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; -// const uint16_t * GGML_RESTRICT q2 = x[i].qs; -// const int8_t * GGML_RESTRICT q8 = y[i].qs; -// -// float sumf1 = 0, sumf2 = 0; -// -// for (int ib32 = 0; ib32 < QK_K/32; ib += 2) { -// int8x16_t q8b0 = vec_xl( 0, q8); -// int8x16_t qb81 = vec_xl(16, q8); -// int8x16_t q8b2 = vec_xl(32, q8); -// int8x16_t q8b3 = vec_xl(48, q8); -// q8 += 64; -// -// memcpy(aux32, q2, 4 * sizeof(uint32_t)); -// q2 += 8; -// -// int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) }; -// int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) }; -// int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) }; -// int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) }; -// -// int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 7) & 127)) }; -// int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) }; -// int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >> 0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 7) & 127)) }; -// int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) }; -// -// q2u0 = vec_mul(q2u0, q2s0); -// q2u1 = vec_mul(q2u1, q2s1); -// q2u2 = vec_mul(q2u2, q2s2); -// q2u3 = vec_mul(q2u3, q2s3); -// -// const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1); -// const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3); -// -// sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28)); -// sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28)); -// } -// -// sumf += d * (sumf1 + sumf2); -// } -// -// *s = 0.25f * sumf; -#else - - uint32_t aux32[2]; - const uint8_t * aux8 = (const uint8_t *)aux32; - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(aux32, q2, 2*sizeof(uint32_t)); - q2 += 4; - const uint32_t ls = 2*(aux32[1] >> 28) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); - const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls; - } - sumf += d * bsum; - } - *s = 0.125f * sumf; -#endif -} - -void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_xs * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - ggml_int8x16x4_t q2u; - ggml_int8x16x4_t q2s; - ggml_int8x16x4_t q8b; - - int32x4x4_t scales32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - const uint8x8_t scales8 = vld1_u8(x[i].scales); - const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); - const uint8x8_t scales_h = vshr_n_u8(scales8, 4); - uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); - scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); - const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); - const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); - scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); - scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); - scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); - scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); - int32x4_t sumi = vdupq_n_s32(0); - for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); - q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); - q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511)))); - q2u.val[3] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[6] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[7] & 511)))); - q2s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[0] >> 9))), vld1_s8((const void *)(signs64 + (q2[1] >> 9)))); - q2s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[2] >> 9))), vld1_s8((const void *)(signs64 + (q2[3] >> 9)))); - q2s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[4] >> 9))), vld1_s8((const void *)(signs64 + (q2[5] >> 9)))); - q2s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + (q2[6] >> 9))), vld1_s8((const void *)(signs64 + (q2[7] >> 9)))); - q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); - q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); - q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); - q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); - const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]); - const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); - const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); - const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); - const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); - sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); - q2 += 8; - } - sumf += d*vaddvq_s32(sumi); - } - *s = 0.125f * sumf; - -#elif defined(__AVX2__) - - const __m256i mone = _mm256_set1_epi8(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i bit_selector_mask = _mm256_loadu_si256((const __m256i*)bit_selector_mask_bytes); - const __m256i block_sign_shuffle_1 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_1); - const __m256i block_sign_shuffle_2 = _mm256_loadu_si256((const __m256i*)block_sign_shuffle_mask_2); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m256i bit_helper = _mm256_loadu_si256((const __m256i*)k_bit_helper); - const __m256i m511 = _mm256_set1_epi16(511); - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = _mm_set1_epi64x(aux64); - stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); - const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m256i q2_data = _mm256_loadu_si256((const __m256i*)q2); q2 += 16; - aux_gindex = _mm256_and_si256(q2_data, m511); - - const __m256i partial_sign_bits = _mm256_srli_epi16(q2_data, 9); - const __m256i partial_sign_bits_upper = _mm256_srli_epi16(q2_data, 13); - const __m256i partial_sign_bits_for_counting = _mm256_xor_si256(partial_sign_bits, partial_sign_bits_upper); - - const __m256i odd_bits = _mm256_shuffle_epi8(bit_helper, partial_sign_bits_for_counting); - const __m256i full_sign_bits = _mm256_or_si256(partial_sign_bits, odd_bits); - - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_3 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_4 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - - const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], - iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); - const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], - iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); - const __m256i q2_3 = _mm256_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], - iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); - const __m256i q2_4 = _mm256_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], - iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - - const __m128i full_signs_l = _mm256_castsi256_si128(full_sign_bits); - const __m128i full_signs_h = _mm256_extractf128_si256(full_sign_bits, 1); - const __m256i full_signs_1 = MM256_SET_M128I(full_signs_l, full_signs_l); - const __m256i full_signs_2 = MM256_SET_M128I(full_signs_h, full_signs_h); - - __m256i signs; - signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_1); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_1, block_sign_shuffle_2); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_1); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_3 = _mm256_sign_epi8(q8_3, _mm256_or_si256(signs, mone)); - - signs = _mm256_shuffle_epi8(full_signs_2, block_sign_shuffle_2); - signs = _mm256_cmpeq_epi8(_mm256_and_si256(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_4 = _mm256_sign_epi8(q8_4, _mm256_or_si256(signs, mone)); - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const __m256i dot3 = _mm256_maddubs_epi16(q2_3, q8s_3); - const __m256i dot4 = _mm256_maddubs_epi16(q2_4, q8s_4); - - const __m256i sc1 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0))); - const __m256i sc2 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1))); - const __m256i sc3 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2))); - const __m256i sc4 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3))); - - sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot1, sc1)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot2, sc2)); - sumi1 = _mm256_add_epi32(sumi1, _mm256_madd_epi16(dot3, sc3)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_madd_epi16(dot4, sc4)); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const __m128i mone = _mm_set1_epi8(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i bit_selector_mask_0 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes); - const __m128i bit_selector_mask_1 = _mm_loadu_si128((const __m128i*)bit_selector_mask_bytes + 1); - const __m128i block_sign_shuffle_1_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1); - const __m128i block_sign_shuffle_1_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_1 + 1); - const __m128i block_sign_shuffle_2_0 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2); - const __m128i block_sign_shuffle_2_1 = _mm_loadu_si128((const __m128i*)block_sign_shuffle_mask_2 + 1); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m128i bit_helper_0 = _mm_loadu_si128((const __m128i*)k_bit_helper); - const __m128i bit_helper_1 = _mm_loadu_si128((const __m128i*)k_bit_helper + 1); - const __m128i m511 = _mm_set1_epi16(511); - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = _mm_set1_epi64x(aux64); - stmp = _mm_unpacklo_epi8(_mm_and_si128(stmp, m4), _mm_and_si128(_mm_srli_epi16(stmp, 4), m4)); - const __m128i scales = _mm_add_epi8(_mm_slli_epi16(stmp, 1), m1); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m128i q2_data_0 = _mm_loadu_si128((const __m128i*)q2); - const __m128i q2_data_1 = _mm_loadu_si128((const __m128i*)q2 + 1); q2 += 16; - aux_gindex = MM256_SET_M128I(_mm_and_si128(q2_data_1, m511), _mm_and_si128(q2_data_0, m511)); - - const __m128i partial_sign_bits_0 = _mm_srli_epi16(q2_data_0, 9); - const __m128i partial_sign_bits_1 = _mm_srli_epi16(q2_data_1, 9); - const __m128i partial_sign_bits_upper_0 = _mm_srli_epi16(q2_data_0, 13); - const __m128i partial_sign_bits_upper_1 = _mm_srli_epi16(q2_data_1, 13); - const __m128i partial_sign_bits_for_counting_0 = _mm_xor_si128(partial_sign_bits_0, partial_sign_bits_upper_0); - const __m128i partial_sign_bits_for_counting_1 = _mm_xor_si128(partial_sign_bits_1, partial_sign_bits_upper_1); - - const __m128i odd_bits_0 = _mm_shuffle_epi8(bit_helper_0, partial_sign_bits_for_counting_0); - const __m128i odd_bits_1 = _mm_shuffle_epi8(bit_helper_1, partial_sign_bits_for_counting_1); - const __m128i full_sign_bits_0 = _mm_or_si128(partial_sign_bits_0, odd_bits_0); - const __m128i full_sign_bits_1 = _mm_or_si128(partial_sign_bits_1, odd_bits_1); - - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_3_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_3_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_4_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_4_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i q2_1_0 = _mm_set_epi64x(iq2xs_grid[gindex[1]], iq2xs_grid[gindex[0]]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2xs_grid[gindex[3]], iq2xs_grid[gindex[2]]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2xs_grid[gindex[5]], iq2xs_grid[gindex[4]]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2xs_grid[gindex[7]], iq2xs_grid[gindex[6]]); - const __m128i q2_3_0 = _mm_set_epi64x(iq2xs_grid[gindex[9]], iq2xs_grid[gindex[8]]); - const __m128i q2_3_1 = _mm_set_epi64x(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]]); - const __m128i q2_4_0 = _mm_set_epi64x(iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - const __m128i q2_4_1 = _mm_set_epi64x(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]]); - - // AVX2 full_signs_1 is full_sign_bits_0 here - // AVX2 full_signs_2 is full_sign_bits_1 here - __m128i signs_0, signs_1; - signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_1_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_0, block_sign_shuffle_2_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_1_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_3_0 = _mm_sign_epi8(q8_3_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_3_1 = _mm_sign_epi8(q8_3_1, _mm_or_si128(signs_1, mone)); - - signs_0 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_0); - signs_1 = _mm_shuffle_epi8(full_sign_bits_1, block_sign_shuffle_2_1); - signs_0 = _mm_cmpeq_epi8(_mm_and_si128(signs_0, bit_selector_mask_0), bit_selector_mask_0); - signs_1 = _mm_cmpeq_epi8(_mm_and_si128(signs_1, bit_selector_mask_1), bit_selector_mask_1); - const __m128i q8s_4_0 = _mm_sign_epi8(q8_4_0, _mm_or_si128(signs_0, mone)); - const __m128i q8s_4_1 = _mm_sign_epi8(q8_4_1, _mm_or_si128(signs_1, mone)); - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const __m128i dot3_0 = _mm_maddubs_epi16(q2_3_0, q8s_3_0); - const __m128i dot3_1 = _mm_maddubs_epi16(q2_3_1, q8s_3_1); - const __m128i dot4_0 = _mm_maddubs_epi16(q2_4_0, q8s_4_0); - const __m128i dot4_1 = _mm_maddubs_epi16(q2_4_1, q8s_4_1); - - __m128i sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+0)); - const __m128i sc1_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc1_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+1)); - const __m128i sc2_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc2_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+2)); - const __m128i sc3_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc3_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - sc_tmp = _mm_shuffle_epi8(scales, get_scale_shuffle(ib32+3)); - const __m128i sc4_0 = _mm_cvtepi8_epi16(sc_tmp); - const __m128i sc4_1 = _mm_cvtepi8_epi16(_mm_srli_si128(sc_tmp, 8)); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot1_0, sc1_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot1_1, sc1_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot2_0, sc2_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot2_1, sc2_1)); - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_madd_epi16(dot3_0, sc3_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_madd_epi16(dot3_1, sc3_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_madd_epi16(dot4_0, sc4_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_madd_epi16(dot4_1, sc4_1)); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__loongarch_asx) - - const __m256i mone = __lasx_xvreplgr2vr_b(1); - static const char block_sign_shuffle_mask_1[32] = { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, - }; - static const char block_sign_shuffle_mask_2[32] = { - 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, 0x0a, - 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, 0x0e, - }; - static const uint8_t bit_selector_mask_bytes[32] = { - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i bit_selector_mask = __lasx_xvld((const __m256i*)bit_selector_mask_bytes, 0); - const __m256i block_sign_shuffle_1 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_1, 0); - const __m256i block_sign_shuffle_2 = __lasx_xvld((const __m256i*)block_sign_shuffle_mask_2, 0); - - static const uint8_t k_bit_helper[32] = { - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - 0x00, 0x80, 0x80, 0x00, 0x80, 0x00, 0x00, 0x80, 0x80, 0x00, 0x00, 0x80, 0x00, 0x80, 0x80, 0x00, - }; - const __m256i bit_helper = __lasx_xvld((const __m256i*)k_bit_helper, 0); - const __m256i m511 = __lasx_xvreplgr2vr_h(511); - const __m128i m4 = __lsx_vreplgr2vr_b(0xf); - const __m128i m1 = __lsx_vreplgr2vr_b(1); - - uint64_t aux64; - - // somewhat hacky, but gives a significant boost in performance - __m256i aux_gindex; - const uint16_t * gindex = (const uint16_t *)&aux_gindex; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - __m128i stmp = __lsx_vreplgr2vr_d(aux64); - stmp = __lsx_vilvl_b( __lsx_vand_v(__lsx_vsrli_h(stmp, 4), m4), __lsx_vand_v(stmp, m4)); - const __m128i scales = __lsx_vadd_b(__lsx_vslli_h(stmp, 1), m1); - - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 4) { - - const __m256i q2_data = __lasx_xvld((const __m256i*)q2, 0); q2 += 16; - aux_gindex = __lasx_xvand_v(q2_data, m511); - - const __m256i partial_sign_bits = __lasx_xvsrli_h(q2_data, 9); - const __m256i partial_sign_bits_upper = __lasx_xvsrli_h(q2_data, 13); - const __m256i partial_sign_bits_for_counting = __lasx_xvxor_v(partial_sign_bits, partial_sign_bits_upper); - - const __m256i odd_bits = lasx_shuffle_b(bit_helper, partial_sign_bits_for_counting); - const __m256i full_sign_bits = __lasx_xvor_v(partial_sign_bits, odd_bits); - - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_3 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_4 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - - const __m256i q2_1 = lasx_set_d(iq2xs_grid[gindex[ 3]], iq2xs_grid[gindex[ 2]], - iq2xs_grid[gindex[ 1]], iq2xs_grid[gindex[ 0]]); - const __m256i q2_2 = lasx_set_d(iq2xs_grid[gindex[ 7]], iq2xs_grid[gindex[ 6]], - iq2xs_grid[gindex[ 5]], iq2xs_grid[gindex[ 4]]); - const __m256i q2_3 = lasx_set_d(iq2xs_grid[gindex[11]], iq2xs_grid[gindex[10]], - iq2xs_grid[gindex[ 9]], iq2xs_grid[gindex[ 8]]); - const __m256i q2_4 = lasx_set_d(iq2xs_grid[gindex[15]], iq2xs_grid[gindex[14]], - iq2xs_grid[gindex[13]], iq2xs_grid[gindex[12]]); - - const __m128i full_signs_l = lasx_extracti128(full_sign_bits, 0); - const __m128i full_signs_h = lasx_extracti128(full_sign_bits, 1); - const __m256i full_signs_1 = lasx_insertf128(full_signs_l, full_signs_l); - const __m256i full_signs_2 = lasx_insertf128(full_signs_h, full_signs_h); - - __m256i signs; - signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_1); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_1 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_1); - - signs = lasx_shuffle_b(full_signs_1, block_sign_shuffle_2); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_2 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_2); - - signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_1); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_3 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_3); - - signs = lasx_shuffle_b(full_signs_2, block_sign_shuffle_2); - signs = __lasx_xvseq_b(__lasx_xvand_v(signs, bit_selector_mask), bit_selector_mask); - const __m256i q8s_4 = __lasx_xvsigncov_b(__lasx_xvor_v(signs, mone), q8_4); - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const __m256i dot3 = lasx_maddubs_h(q2_3, q8s_3); - const __m256i dot4 = lasx_maddubs_h(q2_4, q8s_4); - - const __m256i sc1 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+0))); - const __m256i sc2 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+1))); - const __m256i sc3 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+2))); - const __m256i sc4 = lasx_ext8_16(lsx_shuffle_b(scales, get_scale_shuffle(ib32+3))); - - sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot1, sc1)); - sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot2, sc2)); - sumi1 = __lasx_xvadd_w(sumi1, lasx_madd_h(dot3, sc3)); - sumi2 = __lasx_xvadd_w(sumi2, lasx_madd_h(dot4, sc4)); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); -#elif defined(__POWER9_VECTOR__) - const vector int v0 = vec_splats((int32_t)0); - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const uint8_t * GGML_RESTRICT sc = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/64; ++j) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2xs_grid + (q2[0] & 511)), *(const int64_t *)(iq2xs_grid + (q2[1] & 511))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2xs_grid + (q2[2] & 511)), *(const int64_t *)(iq2xs_grid + (q2[3] & 511))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2xs_grid + (q2[4] & 511)), *(const int64_t *)(iq2xs_grid + (q2[5] & 511))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2xs_grid + (q2[6] & 511)), *(const int64_t *)(iq2xs_grid + (q2[7] & 511))}; - - vector signed long long vsigns0 = {*(const int64_t *)(signs64 + ((q2[0] >> 9))), *(const int64_t *)(signs64 + ((q2[1] >> 9)))}; - vector signed long long vsigns1 = {*(const int64_t *)(signs64 + ((q2[2] >> 9))), *(const int64_t *)(signs64 + ((q2[3] >> 9)))}; - vector signed long long vsigns2 = {*(const int64_t *)(signs64 + ((q2[4] >> 9))), *(const int64_t *)(signs64 + ((q2[5] >> 9)))}; - vector signed long long vsigns3 = {*(const int64_t *)(signs64 + ((q2[6] >> 9))), *(const int64_t *)(signs64 + ((q2[7] >> 9)))}; - q2 += 8; - - vector signed char q2x0 = (vector signed char)vec_mul((vector signed char)vsigns0, (vector signed char)aux64x2_0); - vector signed char q2x1 = (vector signed char)vec_mul((vector signed char)vsigns1, (vector signed char)aux64x2_1); - vector signed char q2x2 = (vector signed char)vec_mul((vector signed char)vsigns2, (vector signed char)aux64x2_2); - vector signed char q2x3 = (vector signed char)vec_mul((vector signed char)vsigns3, (vector signed char)aux64x2_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); - const uint16_t ls3 = (uint16_t)(sc[1] >> 4); - sc += 2; - - vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); - vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); - vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - - vsumi0 = vec_msum(qv0, vscales0, vsumi0); - vsumi1 = vec_msum(qv1, vscales1, vsumi1); - vsumi2 = vec_msum(qv2, vscales2, vsumi2); - vsumi3 = vec_msum(qv3, vscales3, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); -#else - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint16_t * GGML_RESTRICT q2 = x[i].qs; - const uint8_t * GGML_RESTRICT sc = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; - const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; - int32_t sumi = 0; - for (int l = 0; l < 2; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); - const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls1; - sumi = 0; - for (int l = 2; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); - const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; - for (int j = 0; j < 8; ++j) { - sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += sumi * ls2; - q2 += 4; - } - sumf += d * bsum; - } - *s = 0.125f * sumf; -#endif -} - -void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq2_s * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); - const uint8x16_t m1 = vdupq_n_u8(1); - const int32x4_t vzero = vdupq_n_s32(0); - - uint8x16x2_t vs; - ggml_int8x16x4_t q2s; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - int sumi1 = 0, sumi2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - q2s.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[0] | ((qh[ib32+0] << 8) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[1] | ((qh[ib32+0] << 6) & 0x300))))); - q2s.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[2] | ((qh[ib32+0] << 4) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[3] | ((qh[ib32+0] << 2) & 0x300))))); - q2s.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[4] | ((qh[ib32+1] << 8) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[5] | ((qh[ib32+1] << 6) & 0x300))))); - q2s.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq2s_grid + (qs[6] | ((qh[ib32+1] << 4) & 0x300)))), - vld1_s8((const int8_t *)(iq2s_grid + (qs[7] | ((qh[ib32+1] << 2) & 0x300))))); - qs += 8; - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); - - q2s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[0]); - q2s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[1]); - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vceqq_u8(vs.val[0], mask2); - vs.val[1] = vceqq_u8(vs.val[1], mask2); - - signs += 4; - - q2s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[0], m1)), q2s.val[2]); - q2s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vorrq_u8(vs.val[1], m1)), q2s.val[3]); - - const int32x4_t p1 = ggml_vdotq_s32(vzero, q2s.val[0], q8b.val[0]); - const int32x4_t p2 = ggml_vdotq_s32(vzero, q2s.val[1], q8b.val[1]); - const int32x4_t p3 = ggml_vdotq_s32(vzero, q2s.val[2], q8b.val[2]); - const int32x4_t p4 = ggml_vdotq_s32(vzero, q2s.val[3], q8b.val[3]); - - sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32+0] & 0xf)); - sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32+0] >> 4)); - sumi1 += vaddvq_s32(p3) * (1 + 2*(x[i].scales[ib32+1] & 0xf)); - sumi2 += vaddvq_s32(p4) * (1 + 2*(x[i].scales[ib32+1] >> 4)); - } - sumf += d*(sumi1 + sumi2); - } - - *s = 0.125f * sumf; - -#elif defined(__AVX2__) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); - const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - - uint64_t aux64; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); - const __m256i scales16 = _mm256_cvtepi8_epi16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], - iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m256i q2_2 = _mm256_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], - iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - qs += 8; - - __m256i aux256 = _mm256_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - - aux256 = _mm256_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+0))); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_shuffle_epi8(scales16, get_scale_shuffle_k4(ib32+1))); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__AVX__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i m4 = _mm_set1_epi8(0xf); - const __m128i m1 = _mm_set1_epi8(1); - - const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); - const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); - const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); - const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); - - uint64_t aux64; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(&aux64, x[i].scales, 8); - const __m128i scales8 = _mm_add_epi8(_mm_slli_epi16(_mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), m4), 1), m1); - const __m128i scales16_0 = _mm_cvtepi8_epi16(scales8); - const __m128i scales16_1 = _mm_cvtepi8_epi16(_mm_srli_si128(scales8, 8)); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q2_1_0 = _mm_set_epi64x(iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m128i q2_1_1 = _mm_set_epi64x(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)]); - const __m128i q2_2_0 = _mm_set_epi64x(iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - const __m128i q2_2_1 = _mm_set_epi64x(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)]); - qs += 8; - - __m128i aux128_0 = _mm_set1_epi32(signs[0] | ((uint32_t) signs[1] << 16)); - __m128i aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); - const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); - - aux128_0 = _mm_set1_epi32(signs[2] | ((uint32_t) signs[3] << 16)); - aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); - const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); - - signs += 4; - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 0))); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+0), 1))); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_shuffle_epi8(scales16_0, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 0))); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_shuffle_epi8(scales16_1, _mm256_extractf128_si256(get_scale_shuffle_k4(ib32+1), 1))); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.125f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector unsigned char mask0 = vec_xl( 0, k_mask1); - const vector unsigned char mask1 = vec_xl(16, k_mask1); - const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * GGML_RESTRICT q2 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); - const uint8_t * GGML_RESTRICT sc = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q2, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq2s_grid + (q2[0] | ((qh[0] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[1] | ((qh[0] << 6) & 0x300)))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq2s_grid + (q2[2] | ((qh[0] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[3] | ((qh[0] << 2) & 0x300)))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq2s_grid + (q2[4] | ((qh[1] << 8) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[5] | ((qh[1] << 6) & 0x300)))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq2s_grid + (q2[6] | ((qh[1] << 4) & 0x300))), *(const int64_t *)(iq2s_grid + (q2[7] | ((qh[1] << 2) & 0x300)))}; - q2 += 8; - qh += 2; - - vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); - vector signed char vsigns23 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); - signs += 4; - - vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); - vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); - vector signed char vsigns2 = vec_perm(vsigns23, vsigns23, mask0); - vector signed char vsigns3 = vec_perm(vsigns23, vsigns23, mask1); - - vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); - vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); - vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); - vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); - - vector signed char q2x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux64x2_0), vsigns0); - vector signed char q2x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux64x2_1), vsigns1); - vector signed char q2x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux64x2_2), vsigns2); - vector signed char q2x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux64x2_3), vsigns3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q2x0, q8y0), vec_mulo(q2x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q2x1, q8y1), vec_mulo(q2x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q2x2, q8y2), vec_mulo(q2x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q2x3, q8y3), vec_mulo(q2x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - const uint16_t ls2 = (uint16_t)(sc[1] & 0xf); - const uint16_t ls3 = (uint16_t)(sc[1] >> 4); - sc += 2; - - vector signed short vscales0 = vec_splats((int16_t)(2*ls0+1)); - vector signed short vscales1 = vec_splats((int16_t)(2*ls1+1)); - vector signed short vscales2 = vec_splats((int16_t)(2*ls2+1)); - vector signed short vscales3 = vec_splats((int16_t)(2*ls3+1)); - - vsumi0 = vec_msum(qv0, vscales0, vsumi0); - vsumi1 = vec_msum(qv1, vscales1, vsumi1); - vsumi2 = vec_msum(qv2, vscales2, vsumi2); - vsumi3 = vec_msum(qv3, vscales3, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.125f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - - const __m128i m4 = __lsx_vreplgr2vr_b(0xf); - const __m128i m1 = __lsx_vreplgr2vr_b(1); - - const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); - const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); - uint64_t aux64; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].qs + QK_K/8); - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - __m128i tmp1; - memcpy(&aux64, x[i].scales, 8); - tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64, 0); - tmp1 = __lsx_vinsgr2vr_d(tmp1, aux64 >> 4, 1); - const __m128i scales8 = __lsx_vadd_b(__lsx_vslli_h(__lsx_vand_v(tmp1, m4), 1), m1); - const __m256i scales16 = lasx_ext8_16(scales8); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15 - - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q2_1 = lasx_set_d(iq2s_grid[qs[3] | ((qh[ib32+0] << 2) & 0x300)], - iq2s_grid[qs[2] | ((qh[ib32+0] << 4) & 0x300)], - iq2s_grid[qs[1] | ((qh[ib32+0] << 6) & 0x300)], - iq2s_grid[qs[0] | ((qh[ib32+0] << 8) & 0x300)]); - const __m256i q2_2 = lasx_set_d(iq2s_grid[qs[7] | ((qh[ib32+1] << 2) & 0x300)], - iq2s_grid[qs[6] | ((qh[ib32+1] << 4) & 0x300)], - iq2s_grid[qs[5] | ((qh[ib32+1] << 6) & 0x300)], - iq2s_grid[qs[4] | ((qh[ib32+1] << 8) & 0x300)]); - qs += 8; - - __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | ((uint32_t) signs[1] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); - - aux256 = __lasx_xvreplgr2vr_w(signs[2] | ((uint32_t) signs[3] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); // blocks 2*ib32+0, 2*ib32+1 - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); // blocks 2*ib32+2, 2*ib32+3 - - const __m256i p1 = lasx_madd_h(dot1, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+0))); - const __m256i p2 = lasx_madd_h(dot2, lasx_shuffle_b(scales16, get_scale_shuffle_k4(ib32+1))); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.125f * hsum_float_8(accumf); - -#else - - float sumf = 0; - for (int i = 0; i < nb; i++) { - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint8_t * signs = qs + QK_K/8; - - int bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); - int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); - int sumi1 = 0, sumi2 = 0; - for (int l = 0; l < 2; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); - for (int j = 0; j < 8; ++j) { - sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - for (int l = 2; l < 4; ++l) { - const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); - for (int j = 0; j < 8; ++j) { - sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); - } - q8 += 8; - } - bsum += ls1 * sumi1 + ls2 * sumi2; - qs += 4; - signs += 4; - } - - sumf += d * bsum; - } - - *s = 0.125f * sumf; - -#endif - -} - -void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq3_xxs * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - ggml_int8x16x4_t q3s; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - float sumf1 = 0, sumf2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); - const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); - const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); - const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); - const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); - q3 += 16; - q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); - q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); - q3s.val[2] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); - q3s.val[3] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); - q3s.val[0] = vmulq_s8(q3s.val[0], vreinterpretq_s8_u32(aux32x4_0)); - q3s.val[1] = vmulq_s8(q3s.val[1], vreinterpretq_s8_u32(aux32x4_1)); - q3s.val[2] = vmulq_s8(q3s.val[2], vreinterpretq_s8_u32(aux32x4_2)); - q3s.val[3] = vmulq_s8(q3s.val[3], vreinterpretq_s8_u32(aux32x4_3)); - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); - sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[0] >> 28)); - sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[1] >> 28)); - } - sumf += d*(sumf1 + sumf2); - } - *s = 0.5f * sumf; - -#elif defined(__AVX2__) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q2_1 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - const __m256i q2_2 = _mm256_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], - signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1); - const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2); - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = 0.25f * hsum_float_8(accumf); - -#elif defined(__AVX__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q2_1_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - const __m128i q2_1_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); - q3 += 8; - const __m128i q2_2_0 = _mm_set_epi32(iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - const __m128i q2_2_1 = _mm_set_epi32(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - const __m128i s2_1_0 = _mm_set_epi64x(signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m128i s2_1_1 = _mm_set_epi64x(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127]); - const __m128i s2_2_0 = _mm_set_epi64x(signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m128i s2_2_1 = _mm_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127]); - const __m128i q8s_1_0 = _mm_sign_epi8(q8_1_0, s2_1_0); - const __m128i q8s_1_1 = _mm_sign_epi8(q8_1_1, s2_1_1); - const __m128i q8s_2_0 = _mm_sign_epi8(q8_2_0, s2_2_0); - const __m128i q8s_2_1 = _mm_sign_epi8(q8_2_1, s2_2_1); - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = 0.25f * hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint32_t * GGML_RESTRICT signs = (const uint32_t *)(x[i].qs + QK_K/4); - const int8_t * GGML_RESTRICT q8 = y[i].qs; - -#pragma GCC unroll 1 - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector unsigned int aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; - vector unsigned int aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; - vector unsigned int aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; - vector unsigned int aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; - q3 += 16; - - vector unsigned long long aux64x2_0 = {(uint64_t)(signs64[(signs[0] >> 0) & 127]), (uint64_t)(signs64[(signs[0] >> 7) & 127])}; - vector unsigned long long aux64x2_1 = {(uint64_t)(signs64[(signs[0] >> 14) & 127]), (uint64_t)(signs64[(signs[0] >> 21) & 127])}; - vector unsigned long long aux64x2_2 = {(uint64_t)(signs64[(signs[1] >> 0) & 127]), (uint64_t)(signs64[(signs[1] >> 7) & 127])}; - vector unsigned long long aux64x2_3 = {(uint64_t)(signs64[(signs[1] >> 14) & 127]), (uint64_t)(signs64[(signs[1] >> 21) & 127])}; - - vector signed char q3x0 = vec_mul((vector signed char)aux64x2_0, (vector signed char)aux32x4_0); - vector signed char q3x1 = vec_mul((vector signed char)aux64x2_1, (vector signed char)aux32x4_1); - vector signed char q3x2 = vec_mul((vector signed char)aux64x2_2, (vector signed char)aux32x4_2); - vector signed char q3x3 = vec_mul((vector signed char)aux64x2_3, (vector signed char)aux32x4_3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(signs[0] >> 28); - const uint16_t ls1 = (uint16_t)(signs[1] >> 28); - signs += 2; - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = 0.25f * vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; - - uint32_t aux32[2]; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q2_1 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - const __m256i q2_2 = lasx_set_w(iq3xxs_grid[q3[7]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[4]], - iq3xxs_grid[q3[3]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[0]]); - q3 += 8; - memcpy(aux32, gas, 8); gas += 8; - - const __m256i s2_1 = lasx_set_d(signs64[(aux32[0] >> 21) & 127], signs64[(aux32[0] >> 14) & 127], - signs64[(aux32[0] >> 7) & 127], signs64[(aux32[0] >> 0) & 127]); - const __m256i s2_2 = lasx_set_d(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127], - signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]); - const __m256i q8s_1 = __lasx_xvsigncov_b(s2_1, q8_1); - const __m256i q8s_2 = __lasx_xvsigncov_b(s2_2, q8_2); - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = aux32[0] >> 28; - const uint16_t ls2 = aux32[1] >> 28; - - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = 0.25f * hsum_float_8(accumf); - -#else - - uint32_t aux32; - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { - memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); - const uint32_t ls = 2*(aux32 >> 28) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); - const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); - const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - q3 += 8; - bsum += sumi * ls; - } - sumf += d * bsum; - } - *s = 0.25f * sumf; -#endif -} - -void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq3_s * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined(__ARM_NEON) - - typedef union { - uint16x8_t vec_index; - uint16_t index[8]; - } vec_index_t; - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; - - const ggml_uint8x16x2_t mask1 = ggml_vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); - - const int16x8_t hshift = vld1q_s16(k_shift); - const uint16x8_t m256 = vdupq_n_u16(256); - const uint8x16_t m1 = vdupq_n_u8(1); - - uint8x16x2_t vs; - ggml_int8x16x4_t q3s; - ggml_int8x16x4_t q8b; - vec_index_t idx; - - uint32_t scales32[2]; - const uint8_t * scales8 = (const uint8_t *)scales32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(scales32, x[i].scales, 4); - scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; - scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; - - int sumi1 = 0, sumi2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; - idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); - const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], - iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); - const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], - iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); - idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); - const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3s_grid[idx.index[0]], iq3s_grid[idx.index[1]], - iq3s_grid[idx.index[2]], iq3s_grid[idx.index[3]]); - const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3s_grid[idx.index[4]], iq3s_grid[idx.index[5]], - iq3s_grid[idx.index[6]], iq3s_grid[idx.index[7]]); - - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | ((uint32_t) signs[1] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); - vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); - - q3s.val[0] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_0)); - q3s.val[1] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_1)); - - vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[2] | ((uint32_t) signs[3] << 16))); - vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); - vs.val[0] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[0]), mask2); - vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), m1); - vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), m1); - - signs += 4; - - q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), vreinterpretq_s8_u32(aux32x4_2)); - q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), vreinterpretq_s8_u32(aux32x4_3)); - - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); - - sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; - sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; - } - sumf += d*(sumi1 + sumi2); - } - *s = sumf; - -#elif defined(__AVX2__) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i mask1 = _mm256_loadu_si256((const __m256i*)k_mask1); - const __m256i mask2 = _mm256_loadu_si256((const __m256i*)k_mask2); - - const __m256i idx_shift = _mm256_set_epi32(1, 2, 3, 4, 5, 6, 7, 8); - const __m256i idx_mask = _mm256_set1_epi32(256); - - typedef union { - __m256i vec[2]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i idx_l = _mm256_cvtepu8_epi16(_mm_loadu_si128((const __m128i *)qs)); qs += 16; - idx.vec[0] = _mm256_set1_epi32(qh[ib32+0]); - idx.vec[1] = _mm256_set1_epi32(qh[ib32+1]); - idx.vec[0] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[0], idx_shift), idx_mask); - idx.vec[1] = _mm256_and_si256(_mm256_sllv_epi32(idx.vec[1], idx_shift), idx_mask); - idx.vec[0] = _mm256_or_si256(idx.vec[0], _mm256_cvtepi16_epi32(_mm256_castsi256_si128(idx_l))); - idx.vec[1] = _mm256_or_si256(idx.vec[1], _mm256_cvtepi16_epi32(_mm256_extractf128_si256(idx_l, 1))); - - // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. - //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); - //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); - const __m256i q2_1 = _mm256_set_epi32( - iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], - iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] - ); - const __m256i q2_2 = _mm256_set_epi32( - iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], - iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] - ); - - __m256i aux256 = _mm256_set1_epi32(signs[0] | (signs[1] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_1 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(s2_1, q8_1), s2_1); - - aux256 = _mm256_set1_epi32(signs[2] | (signs[3] << 16)); - aux256 = _mm256_and_si256(_mm256_shuffle_epi8(aux256,mask1), mask2); - const __m256i s2_2 = _mm256_cmpeq_epi8(aux256, mask2); - const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1); - const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1)); - sumi1 = _mm256_add_epi32(sumi1, p1); - sumi2 = _mm256_add_epi32(sumi2, p2); - } - - accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf); - - } - - *s = hsum_float_8(accumf); - -#elif defined(__AVX__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m128i mask1_0 = _mm_loadu_si128((const __m128i*)k_mask1); - const __m128i mask1_1 = _mm_loadu_si128((const __m128i*)k_mask1 + 1); - const __m128i mask2_0 = _mm_loadu_si128((const __m128i*)k_mask2); - const __m128i mask2_1 = _mm_loadu_si128((const __m128i*)k_mask2 + 1); - - const __m128i idx_mul_0 = _mm_set_epi32(32, 64, 128, 256); - const __m128i idx_mul_1 = _mm_set_epi32(2, 4, 8, 16); - const __m128i idx_mask = _mm_set1_epi32(256); - - typedef union { - __m128i vec[4]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m128i q8_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i qs_tmp = _mm_loadu_si128((const __m128i *)qs); - const __m128i idx_l_0 = _mm_cvtepu8_epi16(qs_tmp); - const __m128i idx_l_1 = _mm_cvtepu8_epi16(_mm_srli_si128(qs_tmp, 8)); qs += 16; - idx.vec[0] = _mm_set1_epi32(qh[ib32+0]); - idx.vec[1] = idx.vec[0]; - idx.vec[2] = _mm_set1_epi32(qh[ib32+1]); - idx.vec[3] = idx.vec[2]; - - idx.vec[0] = _mm_and_si128(_mm_mullo_epi32(idx.vec[0], idx_mul_0), idx_mask); - idx.vec[1] = _mm_and_si128(_mm_mullo_epi32(idx.vec[1], idx_mul_1), idx_mask); - idx.vec[2] = _mm_and_si128(_mm_mullo_epi32(idx.vec[2], idx_mul_0), idx_mask); - idx.vec[3] = _mm_and_si128(_mm_mullo_epi32(idx.vec[3], idx_mul_1), idx_mask); - - idx.vec[0] = _mm_or_si128(idx.vec[0], _mm_cvtepi16_epi32(idx_l_0)); - idx.vec[1] = _mm_or_si128(idx.vec[1], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_0, 8))); - idx.vec[2] = _mm_or_si128(idx.vec[2], _mm_cvtepi16_epi32(idx_l_1)); - idx.vec[3] = _mm_or_si128(idx.vec[3], _mm_cvtepi16_epi32(_mm_srli_si128(idx_l_1, 8))); - - const __m128i q2_1_0 = _mm_set_epi32(iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]]); - const __m128i q2_1_1 = _mm_set_epi32(iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]]); - const __m128i q2_2_0 = _mm_set_epi32(iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[9]], iq3s_grid[idx.index[8]]); - const __m128i q2_2_1 = _mm_set_epi32(iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]]); - - __m128i aux128_0 = _mm_set1_epi32(signs[0] | (signs[1] << 16)); - __m128i aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_1_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_1_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_1_0 = _mm_sub_epi8(_mm_xor_si128(s2_1_0, q8_1_0), s2_1_0); - const __m128i q8s_1_1 = _mm_sub_epi8(_mm_xor_si128(s2_1_1, q8_1_1), s2_1_1); - - aux128_0 = _mm_set1_epi32(signs[2] | (signs[3] << 16)); - aux128_1 = aux128_0; - aux128_0 = _mm_and_si128(_mm_shuffle_epi8(aux128_0,mask1_0), mask2_0); - aux128_1 = _mm_and_si128(_mm_shuffle_epi8(aux128_1,mask1_1), mask2_1); - const __m128i s2_2_0 = _mm_cmpeq_epi8(aux128_0, mask2_0); - const __m128i s2_2_1 = _mm_cmpeq_epi8(aux128_1, mask2_1); - const __m128i q8s_2_0 = _mm_sub_epi8(_mm_xor_si128(s2_2_0, q8_2_0), s2_2_0); - const __m128i q8s_2_1 = _mm_sub_epi8(_mm_xor_si128(s2_2_1, q8_2_1), s2_2_1); - - signs += 4; - - const __m128i dot1_0 = _mm_maddubs_epi16(q2_1_0, q8s_1_0); - const __m128i dot1_1 = _mm_maddubs_epi16(q2_1_1, q8s_1_1); - const __m128i dot2_0 = _mm_maddubs_epi16(q2_2_0, q8s_2_0); - const __m128i dot2_1 = _mm_maddubs_epi16(q2_2_1, q8s_2_1); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(2*ls1+1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(2*ls1+1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(2*ls2+1)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(2*ls2+1)); - sumi1_0 = _mm_add_epi32(sumi1_0, p1_0); - sumi1_1 = _mm_add_epi32(sumi1_1, p1_1); - sumi2_0 = _mm_add_epi32(sumi2_0, p2_0); - sumi2_1 = _mm_add_epi32(sumi2_1, p2_1); - } - - accumf = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_add_epi32(sumi1_1, sumi2_1), _mm_add_epi32(sumi1_0, sumi2_0)))), accumf); - - } - - *s = hsum_float_8(accumf); - -#elif defined(__POWER9_VECTOR__) - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - - const vector int v0 = vec_splats((int32_t)0); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector unsigned char mask0 = vec_xl( 0, k_mask1); - const vector unsigned char mask1 = vec_xl(16, k_mask1); - const vector signed char mask2 = (vector signed char)vec_xl( 0, k_mask2); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)(x[i].signs); - const uint8_t * GGML_RESTRICT sc = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q3, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector unsigned int aux32x4_0 = {iq3s_grid[q3[ 0] | ((qh[0] << 8) & 256)], iq3s_grid[q3[ 1] | ((qh[0] << 7) & 256)], - iq3s_grid[q3[ 2] | ((qh[0] << 6) & 256)], iq3s_grid[q3[ 3] | ((qh[0] << 5) & 256)]}; - vector unsigned int aux32x4_1 = {iq3s_grid[q3[ 4] | ((qh[0] << 4) & 256)], iq3s_grid[q3[ 5] | ((qh[0] << 3) & 256)], - iq3s_grid[q3[ 6] | ((qh[0] << 2) & 256)], iq3s_grid[q3[ 7] | ((qh[0] << 1) & 256)]}; - vector unsigned int aux32x4_2 = {iq3s_grid[q3[ 8] | ((qh[1] << 8) & 256)], iq3s_grid[q3[ 9] | ((qh[1] << 7) & 256)], - iq3s_grid[q3[10] | ((qh[1] << 6) & 256)], iq3s_grid[q3[11] | ((qh[1] << 5) & 256)]}; - vector unsigned int aux32x4_3 = {iq3s_grid[q3[12] | ((qh[1] << 4) & 256)], iq3s_grid[q3[13] | ((qh[1] << 3) & 256)], - iq3s_grid[q3[14] | ((qh[1] << 2) & 256)], iq3s_grid[q3[15] | ((qh[1] << 1) & 256)]}; - q3 += 16; - qh += 2; - - vector signed char vsigns01 = (vector signed char)vec_splats(*(const uint32_t *)&signs[0]); - vector signed char vsigns02 = (vector signed char)vec_splats(*(const uint32_t *)&signs[2]); - signs += 4; - - vector signed char vsigns0 = vec_perm(vsigns01, vsigns01, mask0); - vector signed char vsigns1 = vec_perm(vsigns01, vsigns01, mask1); - vector signed char vsigns2 = vec_perm(vsigns02, vsigns02, mask0); - vector signed char vsigns3 = vec_perm(vsigns02, vsigns02, mask1); - - vsigns0 = (vector signed char)vec_cmpeq(vec_and(vsigns0, mask2), mask2); - vsigns1 = (vector signed char)vec_cmpeq(vec_and(vsigns1, mask2), mask2); - vsigns2 = (vector signed char)vec_cmpeq(vec_and(vsigns2, mask2), mask2); - vsigns3 = (vector signed char)vec_cmpeq(vec_and(vsigns3, mask2), mask2); - - vector signed char q3x0 = vec_sub(vec_xor(vsigns0, (vector signed char)aux32x4_0), vsigns0); - vector signed char q3x1 = vec_sub(vec_xor(vsigns1, (vector signed char)aux32x4_1), vsigns1); - vector signed char q3x2 = vec_sub(vec_xor(vsigns2, (vector signed char)aux32x4_2), vsigns2); - vector signed char q3x3 = vec_sub(vec_xor(vsigns3, (vector signed char)aux32x4_3), vsigns3); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q3x0, q8y0), vec_mulo(q3x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q3x1, q8y1), vec_mulo(q3x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q3x2, q8y2), vec_mulo(q3x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q3x3, q8y3), vec_mulo(q3x3, q8y3)); - - const uint16_t ls0 = (uint16_t)(sc[0] & 0xf); - const uint16_t ls1 = (uint16_t)(sc[0] >> 4); - sc ++; - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, - 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 - }; - - static const uint8_t k_mask2[32] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, - }; - - const __m256i mask1 = __lasx_xvld((const __m256i*)k_mask1, 0); - const __m256i mask2 = __lasx_xvld((const __m256i*)k_mask2, 0); - - __m256i idx_shift = lasx_set_w(1, 2, 3, 4, 5, 6, 7, 8); - const __m256i idx_mask = __lasx_xvreplgr2vr_w(256); - - typedef union { - __m256i vec[2]; - uint32_t index[16]; - } index_t; - - index_t idx; - - __m256 accumf = (__m256)__lasx_xvldi(0); - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint16_t * GGML_RESTRICT signs = (const uint16_t *)x[i].signs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const __m256i q8_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i idx_l = lasx_extu8_16(__lsx_vld(qs, 0)); qs += 16; - idx.vec[0] = __lasx_xvreplgr2vr_w(qh[ib32+0]); - idx.vec[1] = __lasx_xvreplgr2vr_w(qh[ib32+1]); - idx.vec[0] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[0], idx_shift), idx_mask); - idx.vec[1] = __lasx_xvand_v(__lasx_xvsll_w(idx.vec[1], idx_shift), idx_mask); - idx.vec[0] = __lasx_xvor_v(idx.vec[0], lasx_ext16_32(lasx_extracti128(idx_l, 0))); - idx.vec[1] = __lasx_xvor_v(idx.vec[1], lasx_ext16_32(lasx_extracti128(idx_l, 1))); - - // At leat on my CPU (Ryzen 7950X), using _mm256_i32gather_epi32 is slower than _mm256_set_epi32. Strange. - //const __m256i q2_1 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[0], 4); - //const __m256i q2_2 = _mm256_i32gather_epi32((const int *)iq3s_grid, idx.vec[1], 4); - const __m256i q2_1 = lasx_set_w( - iq3s_grid[idx.index[7]], iq3s_grid[idx.index[6]], iq3s_grid[idx.index[5]], iq3s_grid[idx.index[4]], - iq3s_grid[idx.index[3]], iq3s_grid[idx.index[2]], iq3s_grid[idx.index[1]], iq3s_grid[idx.index[0]] - ); - const __m256i q2_2 = lasx_set_w( - iq3s_grid[idx.index[15]], iq3s_grid[idx.index[14]], iq3s_grid[idx.index[13]], iq3s_grid[idx.index[12]], - iq3s_grid[idx.index[11]], iq3s_grid[idx.index[10]], iq3s_grid[idx.index[ 9]], iq3s_grid[idx.index[ 8]] - ); - - __m256i aux256 = __lasx_xvreplgr2vr_w(signs[0] | (signs[1] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_1 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_1 = __lasx_xvsub_b(__lasx_xvxor_v(s2_1, q8_1), s2_1); - - aux256 = __lasx_xvreplgr2vr_w(signs[2] | (signs[3] << 16)); - aux256 = __lasx_xvand_v(lasx_shuffle_b(aux256,mask1), mask2); - const __m256i s2_2 = __lasx_xvseq_b(aux256, mask2); - const __m256i q8s_2 = __lasx_xvsub_b(__lasx_xvxor_v(s2_2, q8_2), s2_2); - - signs += 4; - - const __m256i dot1 = lasx_maddubs_h(q2_1, q8s_1); - const __m256i dot2 = lasx_maddubs_h(q2_2, q8s_2); - const uint16_t ls1 = x[i].scales[ib32/2] & 0xf; - const uint16_t ls2 = x[i].scales[ib32/2] >> 4; - const __m256i p1 = lasx_madd_h(dot1, __lasx_xvreplgr2vr_h(2*ls1+1)); - const __m256i p2 = lasx_madd_h(dot2, __lasx_xvreplgr2vr_h(2*ls2+1)); - sumi1 = __lasx_xvadd_w(sumi1, p1); - sumi2 = __lasx_xvadd_w(sumi2, p2); - } - - accumf = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accumf); - } - - *s = hsum_float_8(accumf); - -#else - - float sumf = 0.f; - for (int i = 0; i < nb; ++i) { - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint8_t * GGML_RESTRICT signs = x[i].signs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - int32_t bsum = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { - const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; - const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; - int32_t sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - qs += 8; - signs += 4; - bsum += sumi * ls1; - sumi = 0; - for (int l = 0; l < 4; ++l) { - const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); - const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); - for (int j = 0; j < 4; ++j) { - sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); - sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); - } - q8 += 8; - } - qs += 8; - signs += 4; - bsum += sumi * ls2; - } - sumf += d * bsum; - } - *s = sumf; -#endif -} - -#if defined(__AVX2__) -static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i ax = _mm256_sign_epi8(x, x); - const __m256i sy = _mm256_sign_epi8(y, x); - return _mm256_maddubs_epi16(ax, sy); -} -#elif defined(__loongarch_asx) -static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { - const __m256i a = __lasx_xvmulwev_h_b(x, y); - const __m256i b = __lasx_xvmulwod_h_b(x, y); - return __lasx_xvadd_h(a, b); -} -#endif - -void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq1_s * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined __ARM_NEON - - ggml_int8x16x4_t q1b; - ggml_int8x16x4_t q8b; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - int sumi1 = 0, sumi2 = 0, sumi3 = 0; - - for (int ib = 0; ib < QK_K/32; ib += 2) { - - q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700))))); - q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700))))); - q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700))))); - q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700))))); - qs += 8; - - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]); - const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]); - - const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - sumi1 += vaddvq_s32(p1) * ls1; - sumi2 += vaddvq_s32(p2) * ls2; - sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1) - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1); - - } - - sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3); - } - - *s = sumf; - -#elif defined __AVX2__ - - __m256 accum = _mm256_setzero_ps(); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m256i sumi = _mm256_setzero_si256(); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { -#ifdef __BMI2__ - const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib], 0x700070007000700ULL); - const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) | _pdep_u64(qh[ib + 1], 0x700070007000700ULL); - const uint16_t *idx1 = (const uint16_t *)(&packed_idx1); - const uint16_t *idx2 = (const uint16_t *)(&packed_idx2); - const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]); - const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]); -#else - const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], - iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); - const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], - iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); -#endif - qs += 8; - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(ls1)); - const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(ls2)); - - sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p1, p2)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - accum = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(sumi), accum); - accum1 += d * sumi1; - - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#elif defined __AVX__ - __m256 accum = _mm256_setzero_ps(); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q1b_1_0 = _mm_set_epi64x(iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)]); - const __m128i q1b_1_1 = _mm_set_epi64x(iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)]); - const __m128i q1b_2_0 = _mm_set_epi64x(iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)]); - const __m128i q1b_2_1 = _mm_set_epi64x(iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)]); - qs += 8; - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); - const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); - const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); - const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - const __m128i p1_0 = _mm_madd_epi16(dot1_0, _mm_set1_epi16(ls1)); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, _mm_set1_epi16(ls1)); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, _mm_set1_epi16(ls2)); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, _mm_set1_epi16(ls2)); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum); - accum1 += d * sumi1; - - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#elif defined(__POWER9_VECTOR__) - const vector unsigned char v0 = vec_splats((unsigned char)0x0); - const vector unsigned short vsign = vec_splats((unsigned short)0x8000); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - for (int i = 0; i < nb; ++i) { - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[i].d)); - vector float vyd = vec_splats(y[i].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = vec_splats((int32_t)0); - vector signed int vsumi1 = vec_splats((int32_t)0); - vector signed int vsumi2 = vec_splats((int32_t)0); - vector signed int vsumi3 = vec_splats((int32_t)0); - vector signed int vsumi8 = vec_splats((int32_t)0); - - const uint8_t * GGML_RESTRICT q1 = x[i].qs; - const uint16_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - const int16_t * GGML_RESTRICT qs = y[i].bsums; - - for (int j = 0; j < QK_K/32; j += 2) { - __builtin_prefetch(q1, 0, 1); - __builtin_prefetch(qh, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed long long aux64x2_0 = {*(const int64_t *)(iq1s_grid + (q1[0] | ((qh[0] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[1] | ((qh[0] << 5) & 0x700)))}; - vector signed long long aux64x2_1 = {*(const int64_t *)(iq1s_grid + (q1[2] | ((qh[0] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[3] | ((qh[0] >> 1) & 0x700)))}; - vector signed long long aux64x2_2 = {*(const int64_t *)(iq1s_grid + (q1[4] | ((qh[1] << 8) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[5] | ((qh[1] << 5) & 0x700)))}; - vector signed long long aux64x2_3 = {*(const int64_t *)(iq1s_grid + (q1[6] | ((qh[1] << 2) & 0x700))), *(const int64_t *)(iq1s_grid + (q1[7] | ((qh[1] >> 1) & 0x700)))}; - q1 += 8; - - vector signed char q1x0 = (vector signed char)aux64x2_0; - vector signed char q1x1 = (vector signed char)aux64x2_1; - vector signed char q1x2 = (vector signed char)aux64x2_2; - vector signed char q1x3 = (vector signed char)aux64x2_3; - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q1x0, q8y0), vec_mulo(q1x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q1x1, q8y1), vec_mulo(q1x1, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q1x2, q8y2), vec_mulo(q1x2, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q1x3, q8y3), vec_mulo(q1x3, q8y3)); - - const uint16_t ls0 = (uint16_t)((qh[0] >> 12) & 7); - const uint16_t ls1 = (uint16_t)((qh[1] >> 12) & 7); - - vector signed short vscales01 = (vector signed short)vec_splats((uint16_t)(2*ls0+1)); - vector signed short vscales23 = (vector signed short)vec_splats((uint16_t)(2*ls1+1)); - vector signed short vscales = vec_sld(vscales23, vscales01, 8); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - - vector signed short q8ysums = vec_xl_len(qs, 8); - qs += 4; - q8ysums = vec_mergeh(q8ysums, (vector signed short)v0); - - vector signed short qxh = (vector signed short)vec_sld(vec_splats(qh[1]), vec_splats(qh[0]), 8); - qh += 2; - vector __bool short vsel = vec_cmpge(qxh, (vector signed short)v0); - - vector signed short q8ysum = vec_sel((vector signed short)vec_xor((vector unsigned short)q8ysums, vsign), q8ysums, vsel); - - vsumi8 = vec_add(vec_mule(q8ysum, vscales), vsumi8); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - - vsumf0 = vec_madd(vec_ctf(vsumi8, 0), vec_mul(vd, vec_splats(IQ1S_DELTA)), vsumf0); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - __m256 accum = (__m256)__lasx_xvldi(0); - float accum1 = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - __m256i sumi = __lasx_xvldi(0); - int sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ib += 2) { - __m256i q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[0] | ((qh[ib+0] << 8) & 0x700)], 0); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[1] | ((qh[ib+0] << 5) & 0x700)], 1); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[2] | ((qh[ib+0] << 2) & 0x700)], 2); - q1b_1 = __lasx_xvinsgr2vr_d(q1b_1, iq1s_grid[qs[3] | ((qh[ib+0] >> 1) & 0x700)], 3); - - __m256i q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[4] | ((qh[ib+1] << 8) & 0x700)], 0); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[5] | ((qh[ib+1] << 5) & 0x700)], 1); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[6] | ((qh[ib+1] << 2) & 0x700)], 2); - q1b_2 = __lasx_xvinsgr2vr_d(q1b_2, iq1s_grid[qs[7] | ((qh[ib+1] >> 1) & 0x700)], 3); - - qs += 8; - const __m256i q8b_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - const __m256i q8b_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - const int16_t ls1 = 2*((qh[ib+0] >> 12) & 7) + 1; - const int16_t ls2 = 2*((qh[ib+1] >> 12) & 7) + 1; - - __m256i tmp1, tmp5, tmp6; - tmp1 = __lasx_xvreplgr2vr_h(ls1); - tmp5 = __lasx_xvmulwev_w_h(dot1, tmp1); - tmp6 = __lasx_xvmulwod_w_h(dot1, tmp1); - const __m256i p1 = __lasx_xvadd_w(tmp5, tmp6); - - tmp1 = __lasx_xvreplgr2vr_h(ls2); - tmp5 = __lasx_xvmulwev_w_h(dot2, tmp1); - tmp6 = __lasx_xvmulwod_w_h(dot2, tmp1); - const __m256i p2 = __lasx_xvadd_w(tmp5, tmp6); - - sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p1, p2)); - sumi1 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * (qh[ib+0] & 0x8000 ? -1 : 1) * ls1 - + (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * (qh[ib+1] & 0x8000 ? -1 : 1) * ls2; - } - - const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); - accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), accum); - accum1 += d * sumi1; - } - - *s = hsum_float_8(accum) + IQ1S_DELTA * accum1; - -#else - - float sumf = 0; - for (int i = 0; i < nb; i++) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint16_t * qh = x[i].qh; - - int sumi = 0, sumi1 = 0; - for (int ib = 0; ib < QK_K/32; ++ib) { - const int ls = 2*((qh[ib] >> 12) & 7) + 1; - const int delta = qh[ib] & 0x8000 ? -1 : 1; - int lsum = 0; - for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); - for (int j = 0; j < 8; ++j) { - lsum += q8[j] * grid[j]; - } - q8 += 8; - } - sumi += ls * lsum; - sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); - qs += 4; - } - - sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); - } - - *s = sumf; - -#endif -} - -void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_iq1_m * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - - iq1m_scale_t scale; - -#if defined __ARM_NEON - const int32x4_t mask = vdupq_n_s32(0x7); - const int32x4_t mone = vdupq_n_s32(1); - const int32x4_t mzero = vdupq_n_s32(0); - - ggml_int8x16x4_t deltas; - deltas.val[0] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(+1)); - deltas.val[1] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(+1)); - deltas.val[2] = vcombine_s8(vdup_n_s8(+1), vdup_n_s8(-1)); - deltas.val[3] = vcombine_s8(vdup_n_s8(-1), vdup_n_s8(-1)); - - ggml_int8x16x4_t q1b; - ggml_int8x16x4_t q8b; - - uint32_t aux32; - const uint8_t * aux8 = (const uint8_t *)&aux32; - - float sumf = 0; - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - int32x4_t sumi1 = mzero; - int32x4_t sumi2 = mzero; - - for (int ib = 0; ib < QK_K/32; ib += 2) { - - q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700))))); - q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700))))); - q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700))))); - q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))), - vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700))))); - - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1])); - const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3])); - const int32x4_t p12 = vpaddq_s32(p1, p2); - - const uint32_t * qh32 = (const uint32_t *)qh; // we are 4-byte aligned, so we can do that - aux32 = ((qh32[0] >> 3) & 0x01010101) | ((qh32[0] >> 6) & 0x02020202); - - const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[0]], q8b.val[0]), ggml_vdotq_s32(mzero, deltas.val[aux8[1]], q8b.val[1])); - const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, deltas.val[aux8[2]], q8b.val[2]), ggml_vdotq_s32(mzero, deltas.val[aux8[3]], q8b.val[3])); - const int32x4_t p34 = vpaddq_s32(p3, p4); - - int32x4_t scales_4 = ggml_vld1q_u32(sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9); - - scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone); - - sumi1 = vmlaq_s32(sumi1, scales_4, p12); - sumi2 = vmlaq_s32(sumi2, scales_4, p34); - - qs += 8; qh += 4; - - } - - sumf += y[i].d * GGML_FP16_TO_FP32(scale.f16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2)); - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m256i mask = _mm256_set1_epi16(0x7); - const __m256i mone = _mm256_set1_epi16(1); - const __m256i mone8 = _mm256_set1_epi8(1); - const __m256i mtwo8 = _mm256_set1_epi8(2); - // VPSHUFB cannot cross 128-bit lanes so odd shifts go to upper half. - const __m256i scales_shift = _mm256_set_epi64x(9, 3, 6, 0); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - // Extract 3-bit scales (16 values) - __m256i scales = _mm256_set1_epi64x(*(const uint64_t*)sc); - scales = _mm256_srlv_epi64(scales, scales_shift); - scales = _mm256_add_epi16(_mm256_slli_epi16(_mm256_and_si256(scales, mask), 1), mone); - - // Indices to repeat each scale 8 times. - __m256i scales_idx1 = _mm256_set1_epi16(0x0100); - __m256i scales_idx2 = _mm256_add_epi8(scales_idx1, _mm256_set1_epi8(8)); - - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib = 0; ib < QK_K/32; ib += 2) { -#ifdef __BMI2__ - const uint64_t packed_idx1 = _pdep_u64(*(const uint32_t *)qs, 0x00ff00ff00ff00ffULL) - | _pdep_u64(*(const uint16_t*)(qh) & 0x7777, 0xf000f000f000f00ULL); - const uint64_t packed_idx2 = _pdep_u64(*(const uint32_t *)(qs + 4), 0x00ff00ff00ff00ffULL) - | _pdep_u64(*(const uint16_t*)(qh + 2) & 0x7777, 0xf000f000f000f00ULL); - const uint16_t *idx1 = (const uint16_t *)(&packed_idx1); - const uint16_t *idx2 = (const uint16_t *)(&packed_idx2); - const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[idx1[3]], iq1s_grid[idx1[2]], iq1s_grid[idx1[1]], iq1s_grid[idx1[0]]); - const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[idx2[3]], iq1s_grid[idx2[2]], iq1s_grid[idx2[1]], iq1s_grid[idx2[0]]); - - // Convert signs to bytes 0x81 (negative) or 0x01 (positive) - const uint64_t delta_sign = _pdep_u64(*(const uint32_t*)(qh) & 0x88888888, 0xf0f0f0f0f0f0f0f0ULL); - const __m256i delta1 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign))); - const __m256i delta2 = _mm256_or_si256(mone8, _mm256_cvtepi8_epi64(_mm_set1_epi32(delta_sign >> 32))); -#else - const __m256i q1b_1 = _mm256_set_epi64x( - iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)], - iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)] - ); - const __m256i q1b_2 = _mm256_set_epi64x( - iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)], - iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)] - ); - - const __m256i delta1 = _mm256_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m256i delta2 = _mm256_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); -#endif - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - - const __m256i dot1 = mul_add_epi8(q1b_1, q8b_1); - const __m256i dot2 = mul_add_epi8(q1b_2, q8b_2); - const __m256i dot3 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_1, delta1)); - const __m256i dot4 = _mm256_maddubs_epi16(mone8, _mm256_sign_epi8(q8b_2, delta2)); - - __m256i scale1 = _mm256_shuffle_epi8(scales, scales_idx1); - __m256i scale2 = _mm256_shuffle_epi8(scales, scales_idx2); - - scales_idx1 = _mm256_add_epi8(scales_idx1, mtwo8); - scales_idx2 = _mm256_add_epi8(scales_idx2, mtwo8); - - const __m256i p1 = _mm256_madd_epi16(dot1, scale1); - const __m256i p2 = _mm256_madd_epi16(dot2, scale2); - const __m256i p3 = _mm256_madd_epi16(dot3, scale1); - const __m256i p4 = _mm256_madd_epi16(dot4, scale2); - - sumi1 = _mm256_add_epi32(sumi1, _mm256_add_epi32(p1, p2)); - sumi2 = _mm256_add_epi32(sumi2, _mm256_add_epi32(p3, p4)); - - qs += 8; qh += 4; - } - - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); - - accum1 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi1), accum1); - accum2 = _mm256_fmadd_ps(d, _mm256_cvtepi32_ps(sumi2), accum2); - } - - *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - -#elif defined __AVX__ - const __m128i mask = _mm_set1_epi16(0x7); - const __m128i mone = _mm_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (int i = 0; i < nb; ++i) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q1b_1_0 = _mm_set_epi64x( - iq1s_grid[qs[1] | (((uint16_t)qh[0] << 4) & 0x700)], iq1s_grid[qs[0] | (((uint16_t)qh[0] << 8) & 0x700)]); - const __m128i q1b_1_1 = _mm_set_epi64x( - iq1s_grid[qs[3] | (((uint16_t)qh[1] << 4) & 0x700)], iq1s_grid[qs[2] | (((uint16_t)qh[1] << 8) & 0x700)]); - const __m128i q1b_2_0 = _mm_set_epi64x( - iq1s_grid[qs[5] | (((uint16_t)qh[2] << 4) & 0x700)], iq1s_grid[qs[4] | (((uint16_t)qh[2] << 8) & 0x700)]); - const __m128i q1b_2_1 = _mm_set_epi64x( - iq1s_grid[qs[7] | (((uint16_t)qh[3] << 4) & 0x700)], iq1s_grid[qs[6] | (((uint16_t)qh[3] << 8) & 0x700)]); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - - const __m128i dot1_0 = mul_add_epi8_sse(q1b_1_0, q8b_1_0); - const __m128i dot1_1 = mul_add_epi8_sse(q1b_1_1, q8b_1_1); - const __m128i dot2_0 = mul_add_epi8_sse(q1b_2_0, q8b_2_0); - const __m128i dot2_1 = mul_add_epi8_sse(q1b_2_1, q8b_2_1); - - const __m128i delta1_0 = _mm_set_epi64x(qh[0] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[0] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta1_1 = _mm_set_epi64x(qh[1] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[1] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta2_0 = _mm_set_epi64x(qh[2] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[2] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - const __m128i delta2_1 = _mm_set_epi64x(qh[3] & 0x80 ? 0xffffffffffffffff : 0x0101010101010101, - qh[3] & 0x08 ? 0xffffffffffffffff : 0x0101010101010101); - - const __m128i dot3_0 = mul_add_epi8_sse(delta1_0, q8b_1_0); - const __m128i dot3_1 = mul_add_epi8_sse(delta1_1, q8b_1_1); - const __m128i dot4_0 = mul_add_epi8_sse(delta2_0, q8b_2_0); - const __m128i dot4_1 = mul_add_epi8_sse(delta2_1, q8b_2_1); - - __m128i scale1_0 = _mm_set1_epi16(sc[ib/2] >> 0); - __m128i scale1_1 = _mm_set1_epi16(sc[ib/2] >> 3); - __m128i scale2_0 = _mm_set1_epi16(sc[ib/2] >> 6); - __m128i scale2_1 = _mm_set1_epi16(sc[ib/2] >> 9); - - scale1_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_0, mask), 1), mone); - scale1_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale1_1, mask), 1), mone); - scale2_0 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_0, mask), 1), mone); - scale2_1 = _mm_add_epi16(_mm_slli_epi16(_mm_and_si128(scale2_1, mask), 1), mone); - const __m128i p1_0 = _mm_madd_epi16(dot1_0, scale1_0); - const __m128i p1_1 = _mm_madd_epi16(dot1_1, scale1_1); - const __m128i p2_0 = _mm_madd_epi16(dot2_0, scale2_0); - const __m128i p2_1 = _mm_madd_epi16(dot2_1, scale2_1); - const __m128i p3_0 = _mm_madd_epi16(dot3_0, scale1_0); - const __m128i p3_1 = _mm_madd_epi16(dot3_1, scale1_1); - const __m128i p4_0 = _mm_madd_epi16(dot4_0, scale2_0); - const __m128i p4_1 = _mm_madd_epi16(dot4_1, scale2_1); - - sumi1_0 = _mm_add_epi32(sumi1_0, _mm_add_epi32(p1_0, p2_0)); - sumi1_1 = _mm_add_epi32(sumi1_1, _mm_add_epi32(p1_1, p2_1)); - sumi2_0 = _mm_add_epi32(sumi2_0, _mm_add_epi32(p3_0, p4_0)); - sumi2_1 = _mm_add_epi32(sumi2_1, _mm_add_epi32(p3_1, p4_1)); - - qs += 8; qh += 4; - } - - const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(scale.f16)); - - accum1 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi1_1, sumi1_0))), accum1); - accum2 = _mm256_add_ps(_mm256_mul_ps(d, _mm256_cvtepi32_ps(MM256_SET_M128I(sumi2_1, sumi2_0))), accum2); - } - - *s = hsum_float_8(accum1) + IQ1M_DELTA * hsum_float_8(accum2); - -#else - - int sum1[2], sum2[2], delta[4]; - - float sumf = 0; - for (int i = 0; i < nb; i++) { - - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - int sumi1 = 0, sumi2 = 0; - for (int ib = 0; ib < QK_K/32; ++ib) { - delta[0] = qh[0] & 0x08 ? -1 : 1; - delta[1] = qh[0] & 0x80 ? -1 : 1; - delta[2] = qh[1] & 0x08 ? -1 : 1; - delta[3] = qh[1] & 0x80 ? -1 : 1; - sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; - for (int l = 0; l < 4; ++l) { - const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); - int lsum1 = 0, lsum2 = 0; - for (int j = 0; j < 8; ++j) { - lsum1 += q8[j] * grid[j]; - lsum2 += q8[j]; - } - q8 += 8; - sum1[l/2] += lsum1; - sum2[l/2] += lsum2*delta[l]; - } - - const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; - const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; - - sumi1 += sum1[0] * ls1 + sum1[1] * ls2; - sumi2 += sum2[0] * ls1 + sum2[1] * ls2; - qs += 4; - qh += 2; - } - - sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); - } - - *s = sumf; - -#endif -} - -void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - assert(n % QK4_NL == 0); - static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); - - const block_iq4_nl * GGML_RESTRICT x = vx; - const block_q8_0 * GGML_RESTRICT y = vy; - - const int nb = n / QK4_NL; - - int ib = 0; - float sumf = 0; - -#if defined __ARM_NEON - const int8x16_t values = vld1q_s8(kvalues_iq4nl); - const uint8x16_t m4b = vdupq_n_u8(0x0f); - uint8x16x2_t q4bits; - int8x16x4_t q4b; - int8x16x4_t q8b; - int32x4_t prod_1, prod_2; - - for (; ib + 1 < nb; ib += 2) { - - q4bits.val[0] = vld1q_u8(x[ib + 0].qs); - q4bits.val[1] = vld1q_u8(x[ib + 1].qs); - q8b.val[0] = vld1q_s8(y[ib + 0].qs); - q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16); - q8b.val[2] = vld1q_s8(y[ib + 1].qs); - q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16); - - q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); - q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); - q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); - q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - - prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); - prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - - sumf += - GGML_FP16_TO_FP32(x[ib+0].d) * GGML_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) + - GGML_FP16_TO_FP32(x[ib+1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2); - } - -#elif defined __AVX2__ - - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - const __m256i mone = _mm256_set1_epi16(1); - - __m256 accum1 = _mm256_setzero_ps(); - __m256 accum2 = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs); - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs); - const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); - const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); - const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), - _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), - _mm256_cvtepi32_ps(p_2), accum2); - } - - sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); - -#elif defined __AVX__ - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - - __m256 accum = _mm256_setzero_ps(); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs); - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs); - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs); - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1); - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs); - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1); - - const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); - const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); - const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); - const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); - - const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1); - const __m256 deltas = quad_fp16_delta_float(x[ib].d, y[ib].d, x[ib + 1].d, y[ib + 1].d); - accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum); - } - - sumf = hsum_float_8(accum); - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector signed int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - - const vector signed char values = vec_xl( 0, kvalues_iq4nl); - -#pragma GCC unroll 4 - for (; ib < nb; ++ib) { - __builtin_prefetch(x[ib].qs, 0, 1); - __builtin_prefetch(y[ib].qs, 0, 1); - - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ib].d)); - vector float vyd = vec_splats(GGML_FP16_TO_FP32(y[ib].d)); - vector float vd = vec_mul(vxd, vyd); - - vector signed char qxs = (vector signed char)vec_xl( 0, x[ib].qs); - vector signed char q4x0 = vec_and(qxs, lowMask); - vector signed char q4x1 = vec_sr(qxs, v4); - - q4x0 = vec_perm(values, values, (vector unsigned char)q4x0); - q4x1 = vec_perm(values, values, (vector unsigned char)q4x1); - - vector signed char q8y0 = vec_xl( 0, y[ib].qs); - vector signed char q8y1 = vec_xl(16, y[ib].qs); - - vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1)); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - - vsumi0 = vec_sum4s(qv0, vsumi0); - vsumi1 = vec_sum4s(qv1, vsumi1); - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - } - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - sumf = vec_extract(vsumf0, 0); - -#elif defined (__loongarch_asx) - - const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - const __m128i m4b = __lsx_vreplgr2vr_b(0x0f); - const __m256i mone = __lasx_xvreplgr2vr_h(1); - - __m256 accum1 = (__m256)__lasx_xvldi(0); - __m256 accum2 = (__m256)__lasx_xvldi(0); - for (; ib + 1 < nb; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)x[ib + 0].qs, 0); - const __m128i q4bits_2 = __lsx_vld((const __m128i*)x[ib + 1].qs, 0); - const __m256i q8b_1 = __lasx_xvld((const __m256i *)y[ib + 0].qs, 0); - const __m256i q8b_2 = __lasx_xvld((const __m256i *)y[ib + 1].qs, 0); - const __m256i q4b_1 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b)), - lsx_shuffle_b(values128, __lsx_vand_v(q4bits_1, m4b))); - const __m256i q4b_2 = lasx_insertf128(lsx_shuffle_b(values128, __lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b)), - lsx_shuffle_b(values128, __lsx_vand_v(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const __m256i p_1 = lasx_madd_h(p16_1, mone); - const __m256i p_2 = lasx_madd_h(p16_2, mone); - accum1 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 0].d)*GGML_FP16_TO_FP32(x[ib + 0].d)), - __lasx_xvffint_s_w(p_1), accum1); - accum2 = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(y[ib + 1].d)*GGML_FP16_TO_FP32(x[ib + 1].d)), - __lasx_xvffint_s_w(p_2), accum2); - } - - sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2)); - -#elif defined(__VXE__) || defined(__VXE2__) - const int8x16_t v_k = vec_xl(0, kvalues_iq4nl); - const uint8x16_t v_m = vec_splat_u8(0x0F); - - for (; ib < nb; ++ib) { - const block_iq4_nl * GGML_RESTRICT x0 = &x[ib]; - const block_q8_0 * GGML_RESTRICT y0 = &y[ib]; - - const uint8x16_t v_x = vec_xl(0, x0->qs); - int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m); - int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4); - - v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl); - v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh); - - const int8x16_t v_yl = vec_xl(0 , y0->qs); - const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); - const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); - - sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]); - } -#endif - for (; ib < nb; ++ib) { - const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < QK4_NL/2; ++j) { - sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; - sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; - } - sumf += d * (sumi1 + sumi2); - } - *s = sumf; -} - -void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - assert(n % QK_K == 0); - - const block_iq4_xs * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - -#if defined __ARM_NEON - const int8x16_t values = vld1q_s8(kvalues_iq4nl); - const uint8x16_t m4b = vdupq_n_u8(0x0f); - ggml_uint8x16x2_t q4bits; - ggml_int8x16x4_t q4b; - ggml_int8x16x4_t q8b; - int32x4_t prod_1, prod_2; - - float sumf = 0; - - for (int ibl = 0; ibl < nb; ++ibl) { - - const int8_t * q8 = y[ibl].qs; - const uint8_t * q4 = x[ibl].qs; - uint16_t h = x[ibl].scales_h; - - int sumi1 = 0, sumi2 = 0; - for (int ib = 0; ib < QK_K/64; ++ib) { - - q4bits = ggml_vld1q_u8_x2(q4); q4 += 32; - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - - q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b)); - q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4)); - q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b)); - q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4)); - - prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]); - prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]); - - int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; - int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; - h >>= 4; - sumi1 += vaddvq_s32(prod_1) * ls1; - sumi2 += vaddvq_s32(prod_2) * ls2; - - } - - sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); - } - - *s = sumf; - -#elif defined __AVX2__ - - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - - __m256 accum = _mm256_setzero_ps(); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m256i sumi1 = _mm256_setzero_si256(); - __m256i sumi2 = _mm256_setzero_si256(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)qs); qs += 16; - const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32; - const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b))); - const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)), - _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - const __m256i p_1 = _mm256_madd_epi16(p16_1, _mm256_set1_epi16(ls1)); - const __m256i p_2 = _mm256_madd_epi16(p16_2, _mm256_set1_epi16(ls2)); - sumi1 = _mm256_add_epi32(p_1, sumi1); - sumi2 = _mm256_add_epi32(p_2, sumi2); - } - accum = _mm256_fmadd_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accum); - } - - *s = hsum_float_8(accum); - -#elif defined __AVX__ - const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_iq4nl); - const __m128i m4b = _mm_set1_epi8(0x0f); - - __m256 accum = _mm256_setzero_ps(); - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m128i sumi1_0 = _mm_setzero_si128(); - __m128i sumi1_1 = _mm_setzero_si128(); - __m128i sumi2_0 = _mm_setzero_si128(); - __m128i sumi2_1 = _mm_setzero_si128(); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)qs); qs += 16; - const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)qs); qs += 16; - const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)q8); q8 += 16; - const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)); - const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)); - const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)); - const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)); - const __m128i p16_1_0 = mul_add_epi8_sse(q4b_1_0, q8b_1_0); - const __m128i p16_1_1 = mul_add_epi8_sse(q4b_1_1, q8b_1_1); - const __m128i p16_2_0 = mul_add_epi8_sse(q4b_2_0, q8b_2_0); - const __m128i p16_2_1 = mul_add_epi8_sse(q4b_2_1, q8b_2_1); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - const __m128i p_1_0 = _mm_madd_epi16(p16_1_0, _mm_set1_epi16(ls1)); - const __m128i p_1_1 = _mm_madd_epi16(p16_1_1, _mm_set1_epi16(ls1)); - const __m128i p_2_0 = _mm_madd_epi16(p16_2_0, _mm_set1_epi16(ls2)); - const __m128i p_2_1 = _mm_madd_epi16(p16_2_1, _mm_set1_epi16(ls2)); - sumi1_0 = _mm_add_epi32(p_1_0, sumi1_0); - sumi1_1 = _mm_add_epi32(p_1_1, sumi1_1); - sumi2_0 = _mm_add_epi32(p_2_0, sumi2_0); - sumi2_1 = _mm_add_epi32(p_2_1, sumi2_1); - } - __m128i sumi12_0 = _mm_add_epi32(sumi1_0, sumi2_0); - __m128i sumi12_1 = _mm_add_epi32(sumi1_1, sumi2_1); - accum = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - _mm256_cvtepi32_ps(MM256_SET_M128I(sumi12_1, sumi12_0))), accum); - } - - *s = hsum_float_8(accum); - -#elif defined(__POWER9_VECTOR__) - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector int v0 = vec_splats((int32_t)0); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - - vector float vsumf0 = vec_splats(0.0f); - vector float vsumf1 = vec_splats(0.0f); - vector float vsumf2 = vec_splats(0.0f); - vector float vsumf3 = vec_splats(0.0f); - - const vector signed char values = vec_xl( 0, kvalues_iq4nl); - - for (int ibl = 0; ibl < nb; ++ibl) { - - vector float vxd = vec_splats(GGML_FP16_TO_FP32(x[ibl].d)); - vector float vyd = vec_splats(y[ibl].d); - vector float vd = vec_mul(vxd, vyd); - - vector signed int vsumi0 = v0; - vector signed int vsumi1 = v0; - vector signed int vsumi2 = v0; - vector signed int vsumi3 = v0; - - uint16_t h = x[ibl].scales_h; - - const uint8_t * GGML_RESTRICT q4 = x[ibl].qs; - const uint8_t * GGML_RESTRICT sc = x[ibl].scales_l; - const int8_t * GGML_RESTRICT q8 = y[ibl].qs; - - for (int ib = 0; ib < QK_K/64; ib ++ ) { - __builtin_prefetch(q4, 0, 1); - __builtin_prefetch(q8, 0, 1); - - vector signed char qxs0 = (vector signed char)vec_xl( 0, q4); - vector signed char qxs1 = (vector signed char)vec_xl(16, q4); - q4 += 32; - - vector signed char q4x00 = (vector signed char)vec_and(qxs0, lowMask); - vector signed char q4x01 = (vector signed char)vec_sr(qxs0, v4); - vector signed char q4x10 = (vector signed char)vec_and(qxs1, lowMask); - vector signed char q4x11 = (vector signed char)vec_sr(qxs1, v4); - - q4x00 = vec_perm(values, values, (vector unsigned char)q4x00); - q4x01 = vec_perm(values, values, (vector unsigned char)q4x01); - q4x10 = vec_perm(values, values, (vector unsigned char)q4x10); - q4x11 = vec_perm(values, values, (vector unsigned char)q4x11); - - vector signed char q8y0 = vec_xl( 0, q8); - vector signed char q8y1 = vec_xl(16, q8); - vector signed char q8y2 = vec_xl(32, q8); - vector signed char q8y3 = vec_xl(48, q8); - q8 += 64; - - vector signed short qv0 = vec_add(vec_mule(q4x00, q8y0), vec_mulo(q4x00, q8y0)); - vector signed short qv1 = vec_add(vec_mule(q4x01, q8y1), vec_mulo(q4x01, q8y1)); - vector signed short qv2 = vec_add(vec_mule(q4x10, q8y2), vec_mulo(q4x10, q8y2)); - vector signed short qv3 = vec_add(vec_mule(q4x11, q8y3), vec_mulo(q4x11, q8y3)); - - const uint16_t ls0 = (uint16_t)(((sc[0] & 0xf) | ((h << 4) & 0x30)) - 32); - const uint16_t ls1 = (uint16_t)(((sc[0] >> 4) | ((h << 2) & 0x30)) - 32); - h >>= 4; - sc ++; - - vector signed short vscales01 = vec_splats((int16_t)ls0); - vector signed short vscales23 = vec_splats((int16_t)ls1); - - vsumi0 = vec_msum(qv0, vscales01, vsumi0); - vsumi1 = vec_msum(qv1, vscales01, vsumi1); - vsumi2 = vec_msum(qv2, vscales23, vsumi2); - vsumi3 = vec_msum(qv3, vscales23, vsumi3); - } - - vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vd, vsumf0); - vsumf1 = vec_madd(vec_ctf(vsumi1, 0), vd, vsumf1); - vsumf2 = vec_madd(vec_ctf(vsumi2, 0), vd, vsumf2); - vsumf3 = vec_madd(vec_ctf(vsumi3, 0), vd, vsumf3); - } - - vsumf0 = vec_add(vsumf0, vsumf2); - vsumf1 = vec_add(vsumf1, vsumf3); - - vsumf0 = vec_add(vsumf0, vsumf1); - - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4)); - vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8)); - - *s = vec_extract(vsumf0, 0); - -#elif defined(__loongarch_asx) - - const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); - - __m256 accum = (__m256)__lasx_xvldi(0); - - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - uint16_t sh = x[ibl].scales_h; - __m256i sumi1 = __lasx_xvldi(0); - __m256i sumi2 = __lasx_xvldi(0); - for (int ib = 0; ib < QK_K/32; ib += 2) { - const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16; - const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32; - const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)), - __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf))); - const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)), - __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf))); - const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1); - const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); - const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32; - const int16_t ls2 = ((x[ibl].scales_l[ib/2] >> 4) | ((sh << 2) & 0x30)) - 32; - sh >>= 4; - const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1)); - const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2)); - sumi1 = __lasx_xvadd_w(p_1, sumi1); - sumi2 = __lasx_xvadd_w(p_2, sumi2); - } - accum = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(GGML_FP16_TO_FP32(x[ibl].d)*y[ibl].d), - __lasx_xvffint_s_w(__lasx_xvadd_w(sumi1, sumi2)), accum); - } - - *s = hsum_float_8(accum); -#elif defined(__VXE__) || defined(__VXE2__) - const int8x16_t v_k = vec_xl(0, kvalues_iq4nl); - const uint8x16_t v_m = vec_splat_u8(0x0F); - - float sumf = 0; - - for (int ibl = 0; ibl < nb; ++ibl) { - const uint8_t * GGML_RESTRICT q4 = x[ibl].qs; - const int8_t * GGML_RESTRICT q8 = y[ibl].qs; - - uint16_t h = x[ibl].scales_h; - - int sumi1 = 0, sumi2 = 0; - for (int ib = 0; ib < QK_K/64; ++ib) { - const uint8x16_t v_x0 = vec_xl(0 , q4); - const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4); - q4 += 32; - - int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m); - int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4); - int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m); - int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4); - - v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l); - v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h); - v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l); - v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h); - - const int8x16_t v_y0 = vec_xl( 0, q8); - const int8x16_t v_y1 = vec_xl(16, q8); - const int8x16_t v_y2 = vec_xl(32, q8); - const int8x16_t v_y3 = vec_xl(48, q8); - q8 += 64; - - int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1); - int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3); - - int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32; - int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; - - h >>= 4; - - sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1; - sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2; - } - - sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); - } - - *s = sumf; - -#else - float sumf = 0; - for (int ibl = 0; ibl < nb; ++ibl) { - const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; - uint16_t h = x[ibl].scales_h; - const uint8_t * qs = x[ibl].qs; - const int8_t * q8 = y[ibl].qs; - for (int ib = 0; ib < QK_K/32; ib += 2) { - const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); - const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); - h >>= 4; - const float d1 = d4d8*(ls1 - 32); - const float d2 = d4d8*(ls2 - 32); - int sumi1 = 0, sumi2 = 0; - for (int j = 0; j < 16; ++j) { - sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; - sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; - } - sumf += d1 * (sumi1 + sumi2); - qs += 16; - q8 += 32; - sumi1 = sumi2 = 0; - for (int j = 0; j < 16; ++j) { - sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; - sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; - } - sumf += d2 * (sumi1 + sumi2); - qs += 16; - q8 += 32; - } - } - *s = sumf; -#endif -} - -// ============================ 4-bit non-linear quants - -void quantize_row_iq4_nl(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - assert(k % QK4_NL == 0); - quantize_row_iq4_nl_ref(x, y, k); -} - -void quantize_row_iq4_xs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { - assert(k % QK_K == 0); - quantize_iq4_xs(x, y, 1, k, NULL); -} diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 504003287..2c12e493b 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3,11 +3,11 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" -#include "ggml-cpu-traits.h" +#include "traits.h" #include "ggml-cpu-impl.h" #include "ggml-cpu.h" #include "ggml-impl.h" -#include "ggml-cpu-quants.h" +#include "quants.h" #include "ggml-threading.h" #include "unary-ops.h" #include "binary-ops.h" @@ -50,19 +50,6 @@ #include "llamafile/sgemm.h" #endif -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) - -// disable POSIX deprecation warnings -// these functions are never going away, anyway -#pragma warning(disable: 4996) - -// unreachable code because of multiple instances of code after GGML_ABORT -#pragma warning(disable: 4702) -#endif - // Note: once we move threading into a separate C++ file // will use std::hardware_destructive_interference_size instead of hardcoding it here // and we'll use C++ attribute syntax. @@ -215,7 +202,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .nrows = 1, }, [GGML_TYPE_F16] = { - .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, + .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_fp16, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot_type = GGML_TYPE_F16, .nrows = 1, @@ -283,7 +270,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .from_float = quantize_row_q4_K, .vec_dot = ggml_vec_dot_q4_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else .nrows = 1, +#endif }, [GGML_TYPE_Q5_K] = { .from_float = quantize_row_q5_K, @@ -295,7 +286,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .from_float = quantize_row_q6_K, .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, +#if defined (__ARM_FEATURE_MATMUL_INT8) + .nrows = 2, +#else .nrows = 1, +#endif }, [GGML_TYPE_IQ2_XXS] = { .from_float = NULL, @@ -356,7 +351,7 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_K, }, [GGML_TYPE_BF16] = { - .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row, + .from_float = (ggml_from_float_t) ggml_cpu_fp32_to_bf16, .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16, .vec_dot_type = GGML_TYPE_BF16, .nrows = 1, @@ -564,6 +559,14 @@ void ggml_barrier(struct ggml_threadpool * tp) { #endif } +void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value) { + atomic_store_explicit(&tp->current_chunk, value, memory_order_relaxed); +} + +int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value) { + return atomic_fetch_add_explicit(&tp->current_chunk, value, memory_order_relaxed); +} + #if defined(__gnu_linux__) static cpu_set_t ggml_get_numa_affinity(void) { cpu_set_t cpuset; @@ -1932,6 +1935,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_back_f32(params, tensor); } break; + case GGML_OP_CONV_2D_DW: + { + ggml_compute_forward_conv_2d_dw(params, tensor); + } break; case GGML_OP_CONV_TRANSPOSE_2D: { ggml_compute_forward_conv_transpose_2d(params, tensor); @@ -2207,6 +2214,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: { @@ -2268,6 +2276,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_IM2COL: case GGML_OP_IM2COL_BACK: + case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { @@ -2417,12 +2426,32 @@ static bool ggml_thread_apply_priority(int32_t prio) { // This is up to the applications. DWORD p = THREAD_PRIORITY_NORMAL; switch (prio) { + case GGML_SCHED_PRIO_LOW: p = THREAD_PRIORITY_BELOW_NORMAL; break; case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break; case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break; case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break; case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break; } + if (prio != GGML_SCHED_PRIO_LOW) { + // Tell Windows that this thread should not be throttled (needs its own CPU core). + // Newer Windows 11 versions aggresively park (offline) CPU cores and often place + // all our threads onto the first 4 cores which results in terrible performance with + // n_threads > 4 + #if _WIN32_WINNT >= 0x0602 + THREAD_POWER_THROTTLING_STATE t; + ZeroMemory(&t, sizeof(t)); + t.Version = THREAD_POWER_THROTTLING_CURRENT_VERSION; + t.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED; + t.StateMask = 0; + + if (!SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling, &t, sizeof(t))) { + GGML_LOG_DEBUG("failed to disable thread power throttling %d : (%d)\n", prio, (int) GetLastError()); + return false; + } + #endif + } + if (prio == GGML_SCHED_PRIO_NORMAL) { // Keep inherited policy/priority return true; @@ -2450,6 +2479,8 @@ static bool ggml_thread_apply_priority(int32_t prio) { struct sched_param p; int32_t policy = SCHED_OTHER; switch (prio) { + // TODO: there seems to be no way to set lower prio on Apple platforms + case GGML_SCHED_PRIO_LOW: policy = SCHED_OTHER; p.sched_priority = 0; break; case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; @@ -2506,6 +2537,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { struct sched_param p; int32_t policy = SCHED_OTHER; switch (prio) { + case GGML_SCHED_PRIO_LOW: policy = SCHED_BATCH; p.sched_priority = 0; break; case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break; case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break; case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break; @@ -3161,6 +3193,93 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g return ggml_graph_compute(cgraph, &cplan); } +void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) { + int64_t i = 0; +#if defined(__F16C__) +#if defined(__AVX512F__) + for (; i + 15 < n; i += 16) { + __m512 x_vec = _mm512_loadu_ps(x + i); + __m256i y_vec = _mm512_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm256_storeu_si256((__m256i *)(y + i), y_vec); + } +#endif + for (; i + 7 < n; i += 8) { + __m256 x_vec = _mm256_loadu_ps(x + i); + __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128((__m128i *)(y + i), y_vec); + } + for (; i + 3 < n; i += 4) { + __m128 x_vec = _mm_loadu_ps(x + i); + __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); + _mm_storel_epi64((__m128i *)(y + i), y_vec); + } +#endif + for (; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(x[i]); + } +} + +void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { + int64_t i = 0; +#if defined(__F16C__) +#if defined(__AVX512F__) + for (; i + 15 < n; i += 16) { + __m256i x_vec = _mm256_loadu_si256((const __m256i *)(x + i)); + __m512 y_vec = _mm512_cvtph_ps(x_vec); + _mm512_storeu_ps(y + i, y_vec); + } +#endif + for (; i + 7 < n; i += 8) { + __m128i x_vec = _mm_loadu_si128((const __m128i *)(x + i)); + __m256 y_vec = _mm256_cvtph_ps(x_vec); + _mm256_storeu_ps(y + i, y_vec); + } + for (; i + 3 < n; i += 4) { + __m128i x_vec = _mm_loadl_epi64((const __m128i *)(x + i)); + __m128 y_vec = _mm_cvtph_ps(x_vec); + _mm_storeu_ps(y + i, y_vec); + } +#endif + for (; i < n; ++i) { + y[i] = GGML_FP16_TO_FP32(x[i]); + } +} + +void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) { + int64_t i = 0; + for (; i < n; ++i) { + y[i] = GGML_FP32_TO_BF16(x[i]); + } +} + +void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { + int64_t i = 0; +#if defined(__AVX2__) +#if defined(__AVX512F__) + for (; i + 15 < n; i += 16) { + _mm512_storeu_ps(y + i, + _mm512_castsi512_ps( + _mm512_slli_epi32( + _mm512_cvtepu16_epi32( + _mm256_loadu_si256( + (const __m256i *)(x + i))), + 16))); + } +#endif + for (; i + 7 < n; i += 8) { + _mm256_storeu_ps(y + i, + _mm256_castsi256_ps( + _mm256_slli_epi32( + _mm256_cvtepu16_epi32( + _mm_loadu_si128( + (const __m128i *)(x + i))), + 16))); + } +#endif + for (; i < n; i++) { + y[i] = GGML_BF16_TO_FP32(x[i]); + } +} int ggml_cpu_has_avx(void) { #if defined(__AVX__) @@ -3400,6 +3519,19 @@ void ggml_cpu_init(void) { const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); + +#ifdef GGML_USE_OPENMP + //if (!getenv("OMP_WAIT_POLICY")) { + // // set the wait policy to active, so that OpenMP threads don't sleep + // putenv("OMP_WAIT_POLICY=active"); + //} + + if (!getenv("KMP_BLOCKTIME")) { + // set the time to wait before sleeping a thread + // this is less aggressive than setting the wait policy to active, but should achieve similar results in most cases + putenv("KMP_BLOCKTIME=200"); // 200ms + } +#endif } #if defined(__ARM_ARCH) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 09f8382b9..735ef3f01 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -1,8 +1,8 @@ #include "ggml-backend.h" #include "ggml-backend-impl.h" #include "ggml-cpu.h" -#include "ggml-cpu-aarch64.h" -#include "ggml-cpu-traits.h" +#include "repack.h" +#include "traits.h" #include "ggml-impl.h" #include "amx/amx.h" @@ -11,24 +11,26 @@ #include #ifdef GGML_USE_CPU_HBM -#include "ggml-cpu-hbm.h" +# include "hbm.h" #endif #ifdef GGML_USE_CPU_KLEIDIAI -#include "kleidiai/kleidiai.h" -#endif - -#if defined(__APPLE__) -#include -#include +# include "kleidiai/kleidiai.h" #endif #if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX - #define NOMINMAX +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +#else +# include #endif -#include + +#if defined(__APPLE__) +# include +# include #endif // ggml-backend interface @@ -49,9 +51,9 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type } #endif -#ifdef GGML_USE_CPU_AARCH64 - if (ggml_backend_cpu_aarch64_buffer_type()) { - bufts.push_back(ggml_backend_cpu_aarch64_buffer_type()); +#ifdef GGML_USE_CPU_REPACK + if (ggml_backend_cpu_repack_buffer_type()) { + bufts.push_back(ggml_backend_cpu_repack_buffer_type()); } #endif @@ -70,8 +72,10 @@ static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_ty } static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) { - for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) { - if (extra && extra == buft) return true; + for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) { + if (extra && extra == buft) { + return true; + } } return false; } @@ -330,9 +334,18 @@ static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t d } static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // TODO - *free = 0; - *total = 0; +#ifdef _WIN32 + MEMORYSTATUSEX status; + status.dwLength = sizeof(status); + GlobalMemoryStatusEx(&status); + *total = status.ullTotalPhys; + *free = status.ullAvailPhys; +#else + long pages = sysconf(_SC_PHYS_PAGES); + long page_size = sysconf(_SC_PAGE_SIZE); + *total = pages * page_size; + *free = *total; +#endif GGML_UNUSED(dev); } @@ -425,6 +438,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st } case GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + case GGML_OP_GET_ROWS_BACK: + return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16; case GGML_OP_OUT_PROD: return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; @@ -581,8 +596,8 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r #ifdef GGML_USE_CPU_KLEIDIAI features.push_back({ "KLEIDIAI", "1" }); #endif - #ifdef GGML_USE_CPU_AARCH64 - features.push_back({ "AARCH64_REPACK", "1" }); + #ifdef GGML_USE_CPU_REPACK + features.push_back({ "REPACK", "1" }); #endif features.push_back({ nullptr, nullptr }); diff --git a/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp b/ggml/src/ggml-cpu/hbm.cpp similarity index 98% rename from ggml/src/ggml-cpu/ggml-cpu-hbm.cpp rename to ggml/src/ggml-cpu/hbm.cpp index fa8dea2af..a4073c15e 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +++ b/ggml/src/ggml-cpu/hbm.cpp @@ -5,7 +5,7 @@ #include "ggml-cpu.h" #include "ggml-impl.h" -#include "ggml-cpu-hbm.h" +#include "hbm.h" // buffer type HBM diff --git a/ggml/src/ggml-cpu/ggml-cpu-hbm.h b/ggml/src/ggml-cpu/hbm.h similarity index 100% rename from ggml/src/ggml-cpu/ggml-cpu-hbm.h rename to ggml/src/ggml-cpu/hbm.h diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index aacc2bb5e..910fd0ee4 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -4,16 +4,22 @@ // KleidiAI micro-kernels #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32.h" -#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" -#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" -#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" +#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" + +#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32.h" +#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" + +#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" +#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" + #include "kai_common.h" #include "kernels.h" @@ -61,6 +67,53 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, }, /* .required_cpu = */ CPU_FEATURE_SME, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, + }, + { + /* SME GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + }, + /* SME GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_lhs_offset = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, + }, + /* .lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, + /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, + /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, + }, + /* .rhs_info = */ { + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + }, + /* .required_cpu = */ CPU_FEATURE_SME, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_F16, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__APPLE__) @@ -105,6 +158,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_MATMUL_INT8) @@ -148,6 +204,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #else @@ -192,6 +251,9 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_DOTPROD) @@ -235,12 +297,33 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, + /* .lhs_type = */ GGML_TYPE_F32, + /* .rhs_type = */ GGML_TYPE_Q4_0, + /* .op_type = */ GGML_TYPE_F32, }, #endif #endif }; -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature features) { +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { + ggml_kleidiai_kernels * kernel = nullptr; + + if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { + for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { + if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && + gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && + gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && + gemm_gemv_kernels[i].op_type == tensor->type) { + kernel = &gemm_gemv_kernels[i]; + break; + } + } + } + + return kernel; +} + +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { ggml_kleidiai_kernels * kernels = nullptr; for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 2ffe97eb4..3b268d4a2 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -4,6 +4,10 @@ #pragma once +#include +#include +#include "ggml.h" + enum cpu_feature { CPU_FEATURE_NONE = 0, CPU_FEATURE_DOTPROD = 1, @@ -26,26 +30,53 @@ struct kernel_info { size_t (*get_nr)(void); size_t (*get_kr)(void); size_t (*get_sr)(void); - size_t (*get_lhs_offset)(size_t m_idx, size_t k, size_t bl); - size_t (*get_rhs_packed_offset)(size_t n_idx, size_t k, size_t bl); + std::variant< + std::function, + std::function + > get_lhs_offset; + std::variant< + std::function, + std::function + > get_rhs_packed_offset; size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); size_t (*get_dst_size)(size_t m, size_t n); - void (*run_kernel)(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, - float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max); + std::variant< + std::function, + std::function + > run_kernel; }; struct lhs_packing_info { size_t (*get_offset)(size_t m_idx, size_t lhs_stride); - size_t (*get_packed_offset)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); - size_t (*packed_size)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); - void (*pack_func)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, - size_t lhs_stride, void* lhs_packed); + std::variant< + std::function, + std::function + > get_packed_offset; + std::variant< + std::function, + std::function + > packed_size; + std::variant< + std::function, + std::function + > pack_func; }; struct rhs_packing_info { - size_t (*packed_size)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); - void (*pack_func)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, - const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params); + std::variant< + std::function, + std::function + > packed_size; + std::variant< + std::function, + std::function + > pack_func; }; struct ggml_kleidiai_kernels { @@ -55,6 +86,10 @@ struct ggml_kleidiai_kernels { rhs_packing_info rhs_info; cpu_feature required_cpu; + ggml_type lhs_type; + ggml_type rhs_type; + ggml_type op_type; }; -ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features); +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor); +ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 4e89ca0fa..fafe45e6c 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -3,7 +3,9 @@ // #include #include +#include #include +#include #include #include #if defined(__linux__) @@ -24,7 +26,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-threading.h" -#include "ggml-cpu-traits.h" +#include "traits.h" #include "kernels.h" @@ -34,8 +36,9 @@ #include "ggml-common.h" struct ggml_kleidiai_context { + cpu_feature features; ggml_kleidiai_kernels * kernels; -} static ctx = { NULL }; +} static ctx = { CPU_FEATURE_NONE, NULL }; static void init_kleidiai_context(void) { @@ -47,18 +50,18 @@ static void init_kleidiai_context(void) { const char *env_var = getenv("GGML_KLEIDIAI_SME"); int sme_enabled = 0; - cpu_feature features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | - (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | - (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); + ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | + (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | + (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); if (env_var) { sme_enabled = atoi(env_var); } if (sme_enabled != 0) { - features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; } - ctx.kernels = ggml_kleidiai_select_kernels(features); + ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features); } ggml_critical_section_end(); } @@ -68,95 +71,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } +template +static Ret variant_call(const Variant & var, Args&&... args) { + return std::visit([&](auto&& func) -> Ret { + if constexpr (std::is_invocable_r_v) { + return func(std::forward(args)...); + } else { + throw std::runtime_error("Invalid function type in variant_call"); + } + }, var); +} + namespace ggml::cpu::kleidiai { + +static size_t round_down(size_t x, size_t y) { + return y == 0 ? x : x - (x % y); +} + +static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) { + size_t src_stride = rhs_stride / sizeof(uint16_t); + size_t dst_stride = n; + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + for (size_t n_idx = 0; n_idx < n; ++n_idx) { + uint16_t v = *(src + k_idx + n_idx * src_stride); + *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v); + } + } +} + class tensor_traits : public ggml::cpu::tensor_traits { bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - GGML_ASSERT(ctx.kernels); - kernel_info * kernel = op->src[1]->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); + GGML_ASSERT(kernels); + kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; size_t k = op->src[0]->ne[0]; + size_t n = op->src[0]->ne[1]; size_t m = op->src[1]->ne[1]; size_t mr = kernel->get_mr(); size_t kr = kernel->get_kr(); size_t sr = kernel->get_sr(); - size = ctx.kernels->lhs_info.packed_size(m, k, QK4_0, mr, kr, sr); + if (kernels->rhs_type == GGML_TYPE_Q4_0) { + size = variant_call(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr); + } else if (kernels->rhs_type == GGML_TYPE_F16) { + size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr) + + variant_call(kernels->rhs_info.packed_size, n, k) + + k * n * sizeof(float) + n * sizeof(float); + } else { + GGML_ASSERT(false); + } return true; } + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { if (dst->op == GGML_OP_MUL_MAT) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; + if (dst->src[0]->type == GGML_TYPE_Q4_0) { + return compute_forward_q4_0(params, dst); + } else if (dst->src[0]->type == GGML_TYPE_F16) { + return compute_forward_kv_cache(params, dst); + } + } + return false; + } - GGML_TENSOR_BINARY_OP_LOCALS + bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { + static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; - GGML_ASSERT(ctx.kernels); - kernel_info * kernel = src1->ne[1] == 1 ? &ctx.kernels->gemv : &ctx.kernels->gemm; - lhs_packing_info * lhs_info = &ctx.kernels->lhs_info; + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(kernel); + GGML_TENSOR_BINARY_OP_LOCALS - const int ith = params->ith; - const int nth = params->nth; + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); + GGML_ASSERT(kernels); - const size_t k = ne00; - const size_t m = ne11; - const size_t n = ne01; + kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; + GGML_ASSERT(kernel); - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; + const int nth = params->nth; + const int ith = params->ith; - size_t n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; + const int64_t lhs_batch_size0 = ne12; + const int64_t rhs_batch_size0 = ne02; + const int64_t batch_size = rhs_batch_size0; + + const int64_t r = lhs_batch_size0 / rhs_batch_size0; + + const int64_t m = ne11 * r; + const int64_t n = ne01; + const int64_t k = ne00; + + const size_t lhs_stride = src1->nb[1]; + const size_t rhs_stride = src0->nb[1]; + const size_t dst_stride = dst->nb[1]; + + const int64_t mr = static_cast(kernel->get_mr()); + const int64_t nr = static_cast(kernel->get_nr()); + const int64_t kr = static_cast(kernel->get_kr()); + const int64_t sr = static_cast(kernel->get_sr()); + + const size_t lhs_packed_size = variant_call(kernels->lhs_info.packed_size, m, k, mr, kr, sr); + const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k); + const size_t kxn_size = k * n * sizeof(float); + const size_t bias_size = n * sizeof(float); + + const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; + GGML_ASSERT(wsize_required <= params->wsize); + + uint8_t * lhs_packed = static_cast(params->wdata); + uint8_t * rhs_packed = lhs_packed + lhs_packed_size; + uint8_t * rhs_kxn = rhs_packed + rhs_packed_size; + uint8_t * bias = rhs_kxn + kxn_size; + + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const uint8_t * lhs_batch = static_cast(src1->data) + batch_idx * m * lhs_stride; + const uint8_t * rhs_batch = static_cast(src0->data) + batch_idx * n * rhs_stride; + uint8_t * dst_batch = static_cast(dst->data) + batch_idx * m * dst_stride; + + // LHS packing + { + const int64_t m_roundup_mr = kai_roundup(m, mr); + const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth); + + if (ith < num_threads) { + const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr); + const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0; + + const int64_t m_start = ith * num_m_per_thread0; + const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; + + const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride); + const size_t lhs_packed_offset = variant_call(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr); + + const void * src_ptr = static_cast(lhs_batch) + lhs_offset; + void * dst_ptr = static_cast(lhs_packed) + lhs_packed_offset; + + variant_call(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); + } } - const uint8_t * lhs = static_cast(src1->data); - uint8_t * lhs_packed = (uint8_t*)params->wdata; - const uint8_t * rhs_packed = static_cast(src0->data); + // RHS packing + if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) { + // First thread to reach this point handles RHS packing + memset(bias, 0, n * sizeof(float)); + transpose_f32kxn_f16nxk(n, k, reinterpret_cast(rhs_kxn), + reinterpret_cast(rhs_batch), rhs_stride); - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); - - // Calculate number of columns to be processed per thread - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; - } - - if(m_start < m) { - // Transform LHS - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, QK4_0, mr, kr, sr); - void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - - lhs_info->pack_func(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + variant_call(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float), + rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); } ggml_barrier(params->threadpool); - // Perform the operation - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, QK4_0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); - const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); - float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + first_to_arrive.clear(std::memory_order_release); - kernel->run_kernel(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, - dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); - return true; + // Perform the matmul + { + const int64_t m_to_process = m; + const int64_t m_start = 0; + + const int64_t n_step = static_cast(kernel->get_n_step()); + const int64_t num_threads = KAI_MIN(n / n_step, nth); + + if (ith < num_threads) { + const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step); + const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0; + + const int64_t n_start = ith * num_n_per_thread0; + const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0; + + const size_t lhs_packed_offset = variant_call(kernel->get_lhs_offset, m_start, k); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k); + const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride); + + const void * lhs_ptr = lhs_packed + lhs_packed_offset; + const void * rhs_ptr = rhs_packed + rhs_packed_offset; + float * dst_ptr = reinterpret_cast(dst_batch + dst_offset); + + variant_call(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + } + } + + if (batch_idx != batch_size - 1) { + // This barrier is necessary when the batch size is larger than 1. While processing a batch, + // the work data buffer (params->wdata) is used as temporary storage which means that only + // a single batch can be processed at any given time. No barrier is needed for the last + // batch since GGML inserts a barrier between the execution of every operator. + ggml_barrier(params->threadpool); + } } - return false; + + return true; + } + + bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); + GGML_ASSERT(kernels); + + kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * lhs_info = &kernels->lhs_info; + + GGML_ASSERT(kernel); + + const int ith = params->ith; + const int nth = params->nth; + + const size_t k = ne00; + const size_t m = ne11; + const size_t n = ne01; + + size_t mr = kernel->get_mr(); + size_t kr = kernel->get_kr(); + size_t sr = kernel->get_sr(); + + const uint8_t * lhs = static_cast(src1->data); + uint8_t * lhs_packed = (uint8_t*)params->wdata; + const uint8_t * rhs_packed = static_cast(src0->data); + + const size_t n_step = kernel->get_n_step(); + const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); + const size_t n_start = ith * num_n_per_thread; + + size_t n_to_process = num_n_per_thread; + if ((n_start + n_to_process) > n) { + n_to_process = n - n_start; + } + + // Calculate number of columns to be processed per thread + const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; + const size_t m_start = ith * num_m_per_thread; + size_t m_to_process = num_m_per_thread; + if ((m_start + m_to_process) > m) { + m_to_process = m - m_start; + } + + if (m_start < m) { + // Transform LHS + const size_t src_stride = src1->nb[1]; + const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); + const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); + void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); + + variant_call(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + } + + ggml_barrier(params->threadpool); + + // Perform the operation + const size_t dst_stride = dst->nb[1]; + const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k, QK4_0); + const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); + const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); + const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); + float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + + variant_call(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, + sizeof(float), -FLT_MAX, FLT_MAX); + + return true; } public: @@ -169,13 +352,13 @@ public: size_t sr = ctx.kernels->gemm.get_sr(); #ifndef NDEBUG - const size_t repacked_size = ctx.kernels->rhs_info.packed_size(n, k, nr, kr, QK4_0); + const size_t repacked_size = variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); #endif struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; - ctx.kernels->rhs_info.pack_func(1, n, k, nr, kr, sr, QK4_0, (const uint8_t *)data, NULL, tensor->data, 0, ¶ms); + variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms); return 0; @@ -189,7 +372,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc } } // namespace ggml::cpu::kleidiai -GGML_API enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { +static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); GGML_UNUSED(buffer); @@ -238,12 +421,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b namespace ggml::cpu::kleidiai { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - if ( op->op == GGML_OP_MUL_MAT && - op->src[0]->type == GGML_TYPE_Q4_0 && - op->src[0]->buffer && - (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels - ) { + if (op->op == GGML_OP_MUL_MAT && + op->src[0]->type == GGML_TYPE_Q4_0 && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } @@ -260,6 +442,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } + else if (ggml_kleidiai_select_kernels(ctx.features, op) && + op->src[0]->op == GGML_OP_VIEW && + (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) && + op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[0] != 2) || + (op->src[1]->nb[0] != 4) || + (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { + return nullptr; + } + + return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); + } } return nullptr; } diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index f6374f789..1c545f803 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -53,7 +53,6 @@ #include "ggml-cpu-impl.h" #include "ggml-quants.h" -#include #include #include @@ -394,8 +393,6 @@ class tinyBLAS { template NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) { - static std::atomic current_chunk; - GGML_ASSERT(m % (RM * BM) == 0); const int64_t ytiles = m / (RM * BM); const int64_t xtiles = (n + RN -1) / RN; @@ -410,7 +407,7 @@ class tinyBLAS { if (params->ith == 0) { GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles); // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - std::atomic_store_explicit(¤t_chunk, (int64_t)params->nth, std::memory_order_relaxed); + ggml_threadpool_chunk_set(params->threadpool, params->nth); } ggml_barrier(params->threadpool); @@ -439,8 +436,7 @@ class tinyBLAS { GGML_ASSERT(jj == jj2); } - // next step. - job = std::atomic_fetch_add_explicit(¤t_chunk, (int64_t)1, std::memory_order_relaxed); + job = ggml_threadpool_chunk_add(params->threadpool, 1); } ggml_barrier(params->threadpool); @@ -1054,6 +1050,493 @@ class tinyBLAS_Q0_AVX { } \ } \ +template +class tinyBLAS_BF16_PPC { + public: + tinyBLAS_BF16_PPC(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); + } + + private: + void vector_permute_store(vec_t *c, int numVec, unsigned char *vecOffset) { + vec_t t[8], s[8]; + vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + + if (numVec == 2) { + t[0] = vec_perm(c[0], c[1], swiz1); + t[1] = vec_perm(c[2], c[3], swiz1); + s[0] = vec_perm(t[0], t[1], swiz3); + s[1] = vec_perm(t[0], t[1], swiz4); + vec_xst(s[0], 0, (vec_t*)vecOffset); + vec_xst(s[1], 0, (vec_t*)(vecOffset + 16)); + } else if (numVec == 4) { + t[0] = vec_perm(c[0], c[1], swiz1); + t[1] = vec_perm(c[0], c[1], swiz2); + t[2] = vec_perm(c[2], c[3], swiz1); + t[3] = vec_perm(c[2], c[3], swiz2); + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + for (int i = 0; i < 4; ++i) + vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16)); + } else if (numVec == 8) { + for (int i = 0; i < 4; i += 2) { + t[i+0] = vec_perm(c[i+0], c[i+1], swiz1); + t[i+1] = vec_perm(c[i+0], c[i+1], swiz2); + } + for (int i = 4; i < 8; i += 2) { + t[i+0] = vec_perm(c[i+0], c[i+1], swiz1); + t[i+1] = vec_perm(c[i+0], c[i+1], swiz2); + } + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + s[4] = vec_perm(t[4], t[6], swiz3); + s[5] = vec_perm(t[4], t[6], swiz4); + s[6] = vec_perm(t[5], t[7], swiz3); + s[7] = vec_perm(t[5], t[7], swiz4); + for (int i = 0; i < 8; ++i) + vec_xst(s[i], 0, (vec_t*)(vecOffset + i * 16)); + } + } + + void packNormal(const TA* a, int64_t lda, int rows, int cols, unsigned char* vec) { + int64_t i, j; + TA *aoffset = NULL; + unsigned char *vecOffset = NULL; + TA * aoffsets[8]; + vector unsigned char c_arr[8]; + aoffset = const_cast(a); + vecOffset = vec; + j = (rows >> 3); + if (j > 0) { + do { + if (cols == 4) { + aoffsets[0] = aoffset; + for (int it = 1; it < 4; ++it) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; + for (int i = 0; i < 4; ++i) + c_arr[i] = vec_xl(0, (vector unsigned char*)aoffsets[i]); + vector_permute_store(c_arr, 4, vecOffset); + for (int i = 0; i<4; i++) + aoffsets[i] = aoffsets[i]+lda; + vecOffset +=64; + } + i = (cols >> 3); + if (i > 0) { + aoffsets[0] = aoffset; + for (int it = 1; it < 8; ++it) { + aoffsets[it] = aoffsets[it-1] + lda; + } + aoffset += 8 * lda; + do { + for (int it = 0; it < 8; ++it) + c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]); + vector_permute_store(c_arr, 8, vecOffset); + for (int it = 0; it < 8; ++it) + aoffsets[it] = aoffsets[it] + 8*lda; + vecOffset += 128; + i--; + } while(i > 0); + } + j--; + } while(j > 0); + } + if (rows & 4) { + aoffsets[0] = aoffset; + for (int it = 1; it < 4; ++it) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; + if (cols == 4) { + for (int it = 0; it < 4; ++it) + c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]); + vector_permute_store(c_arr, 2, vecOffset); + for (int it = 0; it< 4; it++) + aoffsets[it] = aoffsets[it] + lda; + vecOffset += 32; + } + i = (cols >> 3); + if (i > 0) { + do { + for (int it = 0; it < 4; ++it) + c_arr[it] = vec_xl(0, (vector unsigned char*)aoffsets[it]); + vector_permute_store(c_arr, 4, vecOffset); + for (int it = 0; it< 4; it++) + aoffsets[it] = aoffsets[it] + 8*lda; + vecOffset += 64; + i--; + } while(i > 0); + } + } + if (rows & 3) { + aoffsets[0] = aoffset; + for (int it = 1; it < 4; ++it) + aoffsets[it] = aoffsets[it-1] + lda; + if (cols == 4) { + switch(rows) { + case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]); + case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]); + case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]); + break; + } + vector_permute_store(c_arr, 2, vecOffset); + for (int it = 0; it< 4; it++) + aoffsets[it] = aoffsets[it] + lda; + vecOffset += 32; + } + i = (cols >> 3); + if (i > 0) { + do { + switch(rows) { + case 3: c_arr[2] = vec_xl(0, (vector unsigned char*)aoffsets[2]); + case 2: c_arr[1] = vec_xl(0, (vector unsigned char*)aoffsets[1]); + case 1: c_arr[0] = vec_xl(0, (vector unsigned char*)aoffsets[0]); + break; + } + vector_permute_store(c_arr, 4, vecOffset); + for (int it = 0; it <4; it++) + aoffsets[it] = aoffsets[it] + 8* lda; + vecOffset += 64; + i--; + } while(i > 0); + } + } + } + + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8,8>(m0, m, n0, n); + } else if (m_rem >= 4 && n_rem >= 8) { + mc = 4; + nc = 8; + gemm<4,8>(m0, m, n0, n); + } else if (m_rem >=8 && n_rem >=4){ + mc = 8; + nc = 4; + gemm<8,4>(m0, m, n0, n); + } else if ((m_rem < 4) && (n_rem >= 8)) { + nc = 8; + switch(m_rem) { + case 1: + mc = 1; + gemm_Mx8<1>(m0, m, n0, n); + break; + case 2: + mc = 2; + gemm_Mx8<2>(m0, m, n0, n); + break; + case 3: + mc = 3; + gemm_Mx8<3>(m0, m, n0, n); + break; + default: + return; + } + } else if (m_rem >= 4 && n_rem >= 4) { + mc = 4; + nc = 4; + gemm_small<4, 4>(m0, m, n0, n); + } else if ((m_rem > 4) && (n_rem < 4)) { + mc = 4; + switch(n_rem) { + case 1: + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 2: + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 3: + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + + default: + return; + } + } else { + switch((m_rem << 4) | n_rem) { + case 0x43: + mc = 4; + nc = 3; + gemm_small<4, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm_small<4, 2>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; + gemm_small<4, 1>(m0, m, n0, n); + break; + case 0x34: + mc = 3; + nc = 4; + gemm_small<3, 4>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm_small<3, 3>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm_small<3, 2>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm_small<3, 1>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm_small<2,4>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm_small<2, 3>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm_small<2, 2>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm_small<2, 1>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm_small<1, 4>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm_small<1, 3>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm_small<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm_small<1, 1>(m0, m, n0, n); + break; + default: + return; + } + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + + void KERNEL_4x8(int64_t ii, int64_t jj) { + vec_t vec_A[4], vec_B[8] , vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int l = 0; l < k; l+=8) { + packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A); + packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B); + for (int x = 0; x < 4; x++) { + __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + } + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + } + + void KERNEL_8x4(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[4] , vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + for (int l = 0; l < k; l+=8) { + packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A); + packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); + for (int x = 0; x < 4; x++) { + __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]); + } + } + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii+4, jj); + } + + + void KERNEL_8x8(int64_t ii, int64_t jj) { + vec_t vec_A[8], vec_B[8], vec_C[4]; + acc_t acc_0, acc_1, acc_2, acc_3; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(&acc_2); + __builtin_mma_xxsetaccz(&acc_3); + for (int l = 0; l < k; l+=8) { + packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A); + packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B); + for (int x = 0; x < 4; x++) { + __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]); + __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]); + __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]); + } + } + + SAVE_ACC(&acc_0, ii, jj); + SAVE_ACC(&acc_1, ii, jj+4); + SAVE_ACC(&acc_2, ii+4, jj); + SAVE_ACC(&acc_3, ii+4, jj+4); + } + + template + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + vec_t vec_C[4]; + acc_t acc_0; + __builtin_mma_xxsetaccz(&acc_0); + vec_t vec_A[2], vec_B[2]; + for (int l=0; l + void gemm_Mx8(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int RN = 8; + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + vec_t vec_C[4]; + acc_t acc_0, acc_1; + __builtin_mma_xxsetaccz(&acc_0); + __builtin_mma_xxsetaccz(&acc_1); + vec_t vec_A[4], vec_B[8]; + for (int l=0; l + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else { + static_assert(false, "RN/RM values not supported"); + } + } + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + kernel(ii, jj); + } + } + + const TA *const A; + const TB *const B; + TC *C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; + template class tinyBLAS_Q0_PPC { public: @@ -2202,6 +2685,7 @@ class tinyBLAS_PPC { boffset = vec; j = (rows >> 3); if (j > 0) { + do { aoffset1 = aoffset; aoffset2 = aoffset1 + lda; @@ -2875,9 +3359,22 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 (float *)C, ldc}; return tb.matmul(m, n); } +#elif defined(__MMA__) + if ((k % 8)) + return false; + if(Btype == GGML_TYPE_BF16) { + tinyBLAS_BF16_PPC tb{ k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth}; + tb.matmul(m, n); + return true; + } #endif return false; } + case GGML_TYPE_F16: { #if defined(__AVX512F__) if (Btype == GGML_TYPE_F16) { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6050147be..08facb6d0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8,19 +8,6 @@ #include -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) - -// disable POSIX deprecation warnings -// these functions are never going away, anyway -#pragma warning(disable: 4996) - -// unreachable code because of multiple instances of code after GGML_ABORT -#pragma warning(disable: 4702) -#endif - // ggml_compute_forward_dup static void ggml_compute_forward_dup_same_cont( @@ -2704,6 +2691,109 @@ static void ggml_compute_forward_gelu( } } +// ggml_compute_forward_gelu_erf + +static void ggml_compute_forward_gelu_erf_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_erf_f32(nc, + (float *) ((char *) dst->data + i1*( dst->nb[1])), + (float *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + GGML_UNUSED(x); + assert(!isnan(x)); + assert(!isinf(x)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_erf_f16( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + assert(ggml_is_contiguous_1(src0)); + assert(ggml_is_contiguous_1(dst)); + assert(ggml_are_same_shape(src0, dst)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = src0->ne[0]; + const int nr = ggml_nrows(src0); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_gelu_erf_f16(nc, + (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), + (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + +#ifndef NDEBUG + for (int k = 0; k < nc; k++) { + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float v = GGML_FP16_TO_FP32(x); + GGML_UNUSED(v); + assert(!isnan(v)); + assert(!isinf(v)); + } +#endif + } +} + +static void ggml_compute_forward_gelu_erf( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gelu_erf_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_gelu_erf_f16(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_gelu_quick static void ggml_compute_forward_gelu_quick_f32( @@ -4222,7 +4312,7 @@ static void ggml_compute_forward_get_rows_f16( GGML_ASSERT(i01 >= 0 && i01 < ne01); - ggml_fp16_to_fp32_row( + ggml_cpu_fp16_to_fp32( (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } @@ -4263,7 +4353,7 @@ static void ggml_compute_forward_get_rows_bf16( GGML_ASSERT(i01 >= 0 && i01 < ne01); - ggml_bf16_to_fp32_row( + ggml_cpu_bf16_to_fp32( (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03), (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc); } @@ -6064,6 +6154,178 @@ void ggml_compute_forward_conv_transpose_2d( } } +// ggml_compute_forward_conv_2d_dw + +struct ggml_conv_2d_dw_params { + int64_t channels; + int64_t batch; + int64_t src_w; + int64_t src_h; + int64_t dst_w; + int64_t dst_h; + int64_t knl_w; + int64_t knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +}; + +static void ggml_compute_forward_conv_2d_dw_cwhn( + const ggml_compute_params * params, + const ggml_tensor * src, + const ggml_tensor * kernel, + ggml_tensor * dst, + const ggml_conv_2d_dw_params & p) { + + const int64_t c = p.channels; + const float * knl_data = (const float *)kernel->data; + + const int64_t rows_total = p.dst_h * p.batch; + const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth; + const int64_t row_start = params->ith * rows_per_thread; + const int64_t row_end = MIN(row_start + rows_per_thread, rows_total); + +#ifdef GGML_SIMD + const int64_t pkg_size = GGML_F32_EPR; + const int64_t pkg_count = c / pkg_size; + const int64_t c_pkg_end = pkg_count * pkg_size; +#else + const int64_t c_pkg_end = 0; +#endif + + for (int64_t row = row_start; row < row_end; ++row) { + const int64_t dst_y = row % p.dst_h; + const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c; + for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) { + float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c; + const int64_t src_y_base = dst_y * p.stride_y - p.pad_y; + const int64_t src_x_base = dst_x * p.stride_x - p.pad_x; + +#ifdef GGML_SIMD + // Vectorized loop + for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) { + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = src_y_base + knl_y * p.dilation_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = src_x_base + knl_x * p.dilation_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i); + GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i); + sum = GGML_F32_VEC_FMA(sum, k, s); + } + } + GGML_F32_VEC_STORE(dst_data + c_i, sum); + } +#endif + // Scalar loop + for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) { + float sum = 0.0f; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = src_y_base + knl_y * p.dilation_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = src_x_base + knl_x * p.dilation_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i] + * src_data[(src_y * p.src_w + src_x) * c + c_i]; + } + } + dst_data[c_i] = sum; + } + } + } +} + +static void ggml_compute_forward_conv_2d_dw_whcn( + const ggml_compute_params * params, + const ggml_tensor * src, + const ggml_tensor * kernel, + ggml_tensor * dst, + const ggml_conv_2d_dw_params & p) { + + const int64_t n = p.channels * p.batch; + const int64_t per_thread = (n + params->nth - 1) / params->nth; + const int64_t start = params->ith * per_thread; + const int64_t end = MIN(start + per_thread, n); + + for (int64_t i = start; i < end; ++i) { + const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h; + const float * src_data = (const float *)src->data + i * p.src_w * p.src_h; + float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h; + + for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) { + for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) { + + float sum = 0.0f; + for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) { + const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y < 0 || src_y >= p.src_h) { + continue; + } + for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) { + const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x < 0 || src_x >= p.src_w) { + continue; + } + sum += knl_data[knl_y * p.knl_w + knl_x] + * src_data[src_y * p.src_w + src_x]; + } + } + dst_data[dst_y * p.dst_w + dst_x] = sum; + } + } + } +} + +void ggml_compute_forward_conv_2d_dw( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * src = dst->src[1]; + ggml_conv_2d_dw_params p; + p.channels = src->ne[2]; + p.batch = src->ne[3]; + p.src_w = src->ne[0]; + p.src_h = src->ne[1]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.knl_w = kernel->ne[0]; + p.knl_h = kernel->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(kernel->ne[3] == p.channels); + GGML_ASSERT(dst->ne[3] == p.batch); + + if (ggml_is_contiguous(src)) { + ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p); + } else if (ggml_is_contiguous_channels(src)) { + // kernel should also have channels most contiguous in memory + GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]); + ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p); + } else { + GGML_ABORT("non-contiguous memory layout not supported"); + } +} + // ggml_compute_forward_pool_1d_sk_p0 static void ggml_compute_forward_pool_1d_sk_p0( @@ -7371,39 +7633,83 @@ static void ggml_compute_forward_ssm_scan_f32( const int ir1 = MIN(ir0 + dr, nr); const int ir = ir1 - ir0; - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + #ifdef __ARM_FEATURE_SVE + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); + svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); + svfloat32_t r1_vector = GGML_F32_VEC_ZERO; + + for (int64_t k = 0; k < nc; k += svcntw()) { + svfloat32_t vA = GGML_F32_VEC_LOAD(&A[i1*nc + k]); + svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k]); + svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k]); + svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[i1*nc + k]); + + svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA); + t1 = exp_ps_sve(svptrue_b32(), t1); + svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); + + vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); + r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); + + GGML_F32_VEC_STORE(&s[i1*nc + k], vs0); + } + y[i1] = GGML_F32xt_REDUCE_ONE(r1_vector); } - y[i1] = sumf; } } - } + #else + for (int i3 = 0; i3 < n_s; ++i3) { + for (int i2 = 0; i2 < n_t; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} + const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} + const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} + float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} + float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} + + // use the output as the source for the next token-wise iterations + if (i2 > 0) { s0 = s; } + + // d_inner + for (int i1 = 0; i1 < ir; ++i1) { + // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 + float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; + float x_dt = x[i1] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + int i = i0 + i1*nc; + // state = prev_state * dA + dB * x + float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[i0]; + s[i] = state; + } + y[i1] = sumf; + } + } + } + #endif } void ggml_compute_forward_ssm_scan( @@ -7590,6 +7896,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_gelu(params, dst); } break; + case GGML_UNARY_OP_GELU_ERF: + { + ggml_compute_forward_gelu_erf(params, dst); + } break; case GGML_UNARY_OP_GELU_QUICK: { ggml_compute_forward_gelu_quick(params, dst); @@ -7804,6 +8114,14 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #define GGML_F32X_MUL GGML_F32x16_MUL #define GGML_F32X_FMA GGML_F32x16_FMA #define WKV_VECTOR_SIZE 16 + #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + #define GGML_F32X GGML_F32xt + #define GGML_F32X_SET1 GGML_F32xt_SET1 + #define GGML_F32X_LOAD GGML_F32xt_LOAD + #define GGML_F32X_STORE GGML_F32xt_STORE + #define GGML_F32X_MUL GGML_F32xt_MUL + #define GGML_F32X_FMA GGML_F32xt_FMA + #define WKV_VECTOR_SIZE 8 #elif defined(__ARM_NEON) && defined(__aarch64__) #define GGML_F32X GGML_F32x4 #define GGML_F32X_SET1 GGML_F32x4_SET1 @@ -7815,7 +8133,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32( #endif #ifdef WKV_VECTOR_SIZE - const int64_t vec_count = head_size / WKV_VECTOR_SIZE; + int wkv_vector_size; + #if defined(__ARM_FEATURE_SVE) + wkv_vector_size = svcntw(); + #else + wkv_vector_size = WKV_VECTOR_SIZE; + #endif + const int64_t vec_count = head_size / wkv_vector_size; for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; @@ -7845,7 +8169,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val); for (int64_t j = 0; j < vec_count; j++) { - size_t base_j = j * WKV_VECTOR_SIZE; + size_t base_j = j * wkv_vector_size; size_t t_h_j_offset = t_h_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j; @@ -7870,7 +8194,7 @@ static void ggml_compute_forward_rwkv_wkv6_f32( } // Handle remaining elements, this will not be used. - for (int64_t j = vec_count * WKV_VECTOR_SIZE; j < head_size; j++) { + for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; float v_val = v[t_h_j_offset]; @@ -8006,6 +8330,14 @@ static void ggml_compute_forward_gla_f32( #define GGML_F32X_MUL GGML_F32x16_MUL #define GGML_F32X_FMA GGML_F32x16_FMA #define GLA_VECTOR_SIZE 16 + #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + #define GGML_F32X GGML_F32xt + #define GGML_F32X_SET1 GGML_F32xt_SET1 + #define GGML_F32X_LOAD GGML_F32xt_LOAD + #define GGML_F32X_STORE GGML_F32xt_STORE + #define GGML_F32X_MUL GGML_F32xt_MUL + #define GGML_F32X_FMA GGML_F32xt_FMA + #define GLA_VECTOR_SIZE 8 #elif defined(__ARM_NEON) && defined(__aarch64__) #define GGML_F32X GGML_F32x4 #define GGML_F32X_SET1 GGML_F32x4_SET1 @@ -8017,7 +8349,13 @@ static void ggml_compute_forward_gla_f32( #endif #ifdef GLA_VECTOR_SIZE - const int64_t vec_count = head_size / GLA_VECTOR_SIZE; + int gla_vector_size; + #if defined(__ARM_FEATURE_SVE) + gla_vector_size = svcntw(); + #else + gla_vector_size = GLA_VECTOR_SIZE; + #endif + const int64_t vec_count = head_size / gla_vector_size; for (int64_t t = 0; t < T; t++) { size_t t_offset = t * t_stride; @@ -8044,7 +8382,7 @@ static void ggml_compute_forward_gla_f32( GGML_F32X g_vec = GGML_F32X_SET1(g_val); for (int64_t j = 0; j < vec_count; j++) { - size_t base_j = j * GLA_VECTOR_SIZE; + size_t base_j = j * gla_vector_size; size_t t_h_j_offset = t_h_offset + base_j; size_t h_2d_i_j_offset = h_2d_i_offset + base_j; @@ -8068,7 +8406,7 @@ static void ggml_compute_forward_gla_f32( } // Handle remaining elements, this will not be used. - for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) { + for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) { size_t t_h_j_offset = t_h_offset + j; size_t h_2d_i_j_offset = h_2d_i_offset + j; float v_val = v[t_h_j_offset]; @@ -8177,83 +8515,126 @@ static void ggml_compute_forward_rwkv_wkv7_f32( int64_t h_stride_2d = head_size * head_size; #if defined(GGML_SIMD) - for (int64_t t = 0; t < T; t++) { - int64_t t_offset = t * t_stride; - int64_t state_offset = head_size * C * (t / (T / n_seqs)); - float * state_cur = state + state_offset; - float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + #if defined(__ARM_FEATURE_SVE) + // scalar Route to scalar implementation //TODO: Write SVE code + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; - for (int64_t h = h_start; h < h_end; h++) { - int64_t h_offset = h * h_stride; - int64_t t_h_offset = t_offset + h_offset; - int64_t h_2d_offset = h * h_stride_2d; + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; - for (int64_t ii = 0; ii < head_size; ii++) { - int64_t t_h_i_offset = t_h_offset + ii; - int64_t h_2d_i_offset = h_2d_offset + ii * h_stride; + for (int64_t i = 0; i < head_size; i++) { + int64_t t_h_i_offset = t_h_offset + i; + int64_t h_2d_i_offset = h_2d_offset + i * h_stride; - GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]); + float v_val = v[t_h_i_offset]; - float sa = 0; - { - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) { - for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { - ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]); - ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]); - sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]); - } + float sa = 0, result = 0; + for (int64_t j = 0; j < head_size; j++) { + sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j]; } - GGML_F32_VEC_REDUCE(sa, sum); - } - GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa); + for (int64_t j = 0; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; - int64_t j = 0; - GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - for (; j < head_size; j += GGML_F32_STEP) { - for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { - int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR; - int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR; - - GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]); - GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]); - GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); - GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); - - k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); - - GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); - // kv + s * decay + sa * b - state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec); - state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); - GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec); - - result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec); + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v_val * k_val; + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + result += state_cur[h_2d_i_j_offset] * r_val; } - } - GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec); - - // There shouldn't be left-overs though. - for (; j < head_size; j++) { - int64_t t_h_j_offset = t_h_offset + j; - int64_t h_2d_i_j_offset = h_2d_i_offset + j; - - float r_val = r[t_h_j_offset]; - float w_val = w[t_h_j_offset]; - float k_val = k[t_h_j_offset]; - float b_val = b[t_h_j_offset]; - float kv_val = v[t_h_i_offset] * k_val; - - float prev_state_val = state_prev[h_2d_i_j_offset]; - state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; - dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val; + dst_data[t_h_i_offset] = result; } } } - } + #else + for (int64_t t = 0; t < T; t++) { + int64_t t_offset = t * t_stride; + int64_t state_offset = head_size * C * (t / (T / n_seqs)); + float * state_cur = state + state_offset; + float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset; + + for (int64_t h = h_start; h < h_end; h++) { + int64_t h_offset = h * h_stride; + int64_t t_h_offset = t_offset + h_offset; + int64_t h_2d_offset = h * h_stride_2d; + + for (int64_t ii = 0; ii < head_size; ii++) { + int64_t t_h_i_offset = t_h_offset + ii; + int64_t h_2d_i_offset = h_2d_offset + ii * h_stride; + + GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]); + + float sa = 0; + { + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]); + ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]); + sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]); + } + } + GGML_F32_VEC_REDUCE(sa, sum); + } + + GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa); + + int64_t j = 0; + GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + for (; j < head_size; j += GGML_F32_STEP) { + for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) { + int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR; + int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR; + + GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]); + GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]); + GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]); + GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]); + + k_vec = GGML_F32_VEC_MUL(v_vec, k_vec); + + GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]); + // kv + s * decay + sa * b + state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec); + state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec); + GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec); + + result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec); + } + } + GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec); + + // There shouldn't be left-overs though. + for (; j < head_size; j++) { + int64_t t_h_j_offset = t_h_offset + j; + int64_t h_2d_i_j_offset = h_2d_i_offset + j; + + float r_val = r[t_h_j_offset]; + float w_val = w[t_h_j_offset]; + float k_val = k[t_h_j_offset]; + float b_val = b[t_h_j_offset]; + float kv_val = v[t_h_i_offset] * k_val; + + float prev_state_val = state_prev[h_2d_i_j_offset]; + state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val; + dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val; + } + } + } + } + #endif #else for (int64_t t = 0; t < T; t++) { int64_t t_offset = t * t_stride; diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 410a37204..dc081b9e6 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -65,6 +65,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c new file mode 100644 index 000000000..d2e705f28 --- /dev/null +++ b/ggml/src/ggml-cpu/quants.c @@ -0,0 +1,1157 @@ +#define GGML_COMMON_IMPL_C +#include "ggml-common.h" + +#include "ggml-cpu-impl.h" +#include "ggml-quants.h" +#include "quants.h" + +#include "arch-fallback.h" + +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#define GROUP_MAX_EPS 1e-15f +#define GROUP_MAX_EPS_IQ3_XXS 1e-8f +#define GROUP_MAX_EPS_IQ2_S 1e-8f +#define GROUP_MAX_EPS_IQ1_M 1e-7f +#define GROUP_MAX_EPS_IQ1_S 1e-12f + +#define UNUSED GGML_UNUSED + +void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q4_0_ref(x, y, k); +} + +void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q4_1_ref(x, y, k); +} + +void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q5_0_ref(x, y, k); +} + +void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q5_1_ref(x, y, k); +} + +void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q8_0_ref(x, y, k); +} + +void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q8_1_ref(x, y, k); +} + +// +// 2-6 bit quantization in super-blocks +// + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + quantize_row_q2_K_ref(x, vy, k); +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + quantize_row_q3_K_ref(x, vy, k); +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_q4_K * GGML_RESTRICT y = vy; + quantize_row_q4_K_ref(x, y, k); +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_q5_K * GGML_RESTRICT y = vy; + quantize_row_q5_K_ref(x, y, k); +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_q6_K * GGML_RESTRICT y = vy; + quantize_row_q6_K_ref(x, y, k); +} + +// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) + +void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_tq1_0 * GGML_RESTRICT y = vy; + quantize_row_tq1_0_ref(x, y, k); +} + +void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + block_tq2_0 * GGML_RESTRICT y = vy; + quantize_row_tq2_0_ref(x, y, k); +} + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q8_K_ref(x, y, k); +} + +//===================================== Dot products ================================= + +void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F) - 8; + const int v1 = (x[ib].qs[j] >> 4) - 8; + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += sumi*GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d); + } + + *s = sumf; +} + +// TODO: add WASM SIMD +void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + + for (; ib < nb; ++ib) { + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[ib].qs[j] & 0x0F); + const int v1 = (x[ib].qs[j] >> 4); + + sumi0 += (v0 * y[ib].qs[j]); + sumi1 += (v1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = (int8_t)(((x[ib].qs[j] & 0x0F) | xh_0) - 16); + const int32_t x1 = (int8_t)(((x[ib].qs[j] >> 4) | xh_1) - 16); + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)) * sumi; + } + + *s = sumf; +} + +void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + int ib = 0; + float sumf = 0; + + assert(n % qk == 0); + assert(qk == QK5_1); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + for (; ib < nb; ++ib) { + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + int sumi0 = 0; + int sumi1 = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[ib].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[ib].qs[j] >> 4) | xh_1; + + sumi0 += (x0 * y[ib].qs[j]); + sumi1 += (x1 * y[ib].qs[j + qk/2]); + } + + int sumi = sumi0 + sumi1; + sumf += (GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d))*sumi + GGML_FP16_TO_FP32(x[ib].m)*GGML_FP16_TO_FP32(y[ib].s); + } + + *s = sumf; +} + +void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + int ib = 0; + float sumf = 0; + + for (; ib < nb; ++ib) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[ib].qs[j]*y[ib].qs[j]; + } + + sumf += sumi*(GGML_FP16_TO_FP32(x[ib].d)*GGML_FP16_TO_FP32(y[ib].d)); + } + + *s = sumf; +} + +void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; + } + } + } + + for (size_t l = 0; l < 4; ++l) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; + } + } + + sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d); + } + + *s = sumf; +} + +void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t k = 0; k < 32; ++k) { + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); + } + } + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + sumf += (float) sumi * d; + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +} + +void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +} + +void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q5_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const uint8_t * GGML_RESTRICT hm = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +} + +void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * GGML_RESTRICT a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +} + +void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + uint32_t aux32[2]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(aux32, q2, 2*sizeof(uint32_t)); + q2 += 4; + const uint32_t ls = 2*(aux32[1] >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[(aux32[1] >> 7*l) & 127]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT q2 = x[i].qs; + const uint8_t * GGML_RESTRICT sc = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + const uint16_t ls1 = 2*(sc[ib32] & 0xf) + 1; + const uint16_t ls2 = 2*(sc[ib32] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls1; + sumi = 0; + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[l] & 511)); + const uint8_t signs = ksigns_iq2xs[q2[l] >> 9]; + for (int j = 0; j < 8; ++j) { + sumi += grid[j] * q8[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += sumi * ls2; + q2 += 4; + } + sumf += d * bsum; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint8_t * signs = qs + QK_K/8; + + int bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + int ls1 = 1 + 2*(x[i].scales[ib32] & 0xf); + int ls2 = 1 + 2*(x[i].scales[ib32] >> 4); + int sumi1 = 0, sumi2 = 0; + for (int l = 0; l < 2; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + for (int l = 2; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (qs[l] | (qh[ib32] << (8-2*l) & 0x300))); + for (int j = 0; j < 8; ++j) { + sumi2 += q8[j] * grid[j] * (signs[l] & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + } + bsum += ls1 * sumi1 + ls2 * sumi2; + qs += 4; + signs += 4; + } + + sumf += d * bsum; + } + + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + uint32_t aux32; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT gas = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + memcpy(&aux32, gas, sizeof(uint32_t)); gas += sizeof(uint32_t); + const uint32_t ls = 2*(aux32 >> 28) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*l+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*l+1]); + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127]; + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + q3 += 8; + bsum += sumi * ls; + } + sumf += d * bsum; + } + *s = 0.25f * sumf; +} + +void ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.f; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + int32_t bsum = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint32_t ls1 = 2*(x[i].scales[ib32/2] & 0xf) + 1; + const uint32_t ls2 = 2*(x[i].scales[ib32/2] >> 4) + 1; + int32_t sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+0] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+0] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls1; + sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*l+0] | ((qh[ib32+1] << (8-2*l)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*l+1] | ((qh[ib32+1] << (7-2*l)) & 256))); + for (int j = 0; j < 4; ++j) { + sumi += grid1[j] * q8[j+0] * (signs[l] & kmask_iq2xs[j+0] ? -1 : 1); + sumi += grid2[j] * q8[j+4] * (signs[l] & kmask_iq2xs[j+4] ? -1 : 1); + } + q8 += 8; + } + qs += 8; + signs += 4; + bsum += sumi * ls2; + } + sumf += d * bsum; + } + *s = sumf; +} + +void ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; + + int sumi = 0, sumi1 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + const int ls = 2*((qh[ib] >> 12) & 7) + 1; + const int delta = qh[ib] & 0x8000 ? -1 : 1; + int lsum = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((qh[ib] >> 3*l) & 7) << 8))); + for (int j = 0; j < 8; ++j) { + lsum += q8[j] * grid[j]; + } + q8 += 8; + } + sumi += ls * lsum; + sumi1 += ls * delta * (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]); + qs += 4; + } + + sumf += GGML_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + + int sum1[2], sum2[2], delta[4]; + + float sumf = 0; + for (int i = 0; i < nb; i++) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + int sumi1 = 0, sumi2 = 0; + for (int ib = 0; ib < QK_K/32; ++ib) { + delta[0] = qh[0] & 0x08 ? -1 : 1; + delta[1] = qh[0] & 0x80 ? -1 : 1; + delta[2] = qh[1] & 0x08 ? -1 : 1; + delta[3] = qh[1] & 0x80 ? -1 : 1; + sum1[0] = sum1[1] = sum2[0] = sum2[1] = 0; + for (int l = 0; l < 4; ++l) { + const int8_t * grid = (const int8_t *)(iq1s_grid + (qs[l] | (((uint16_t)qh[l/2] << (8 - 4*(l%2))) & 0x700))); + int lsum1 = 0, lsum2 = 0; + for (int j = 0; j < 8; ++j) { + lsum1 += q8[j] * grid[j]; + lsum2 += q8[j]; + } + q8 += 8; + sum1[l/2] += lsum1; + sum2[l/2] += lsum2*delta[l]; + } + + const int ls1 = 2*((sc[ib/2] >> (6*(ib%2)+0)) & 0x7) + 1; + const int ls2 = 2*((sc[ib/2] >> (6*(ib%2)+3)) & 0x7) + 1; + + sumi1 += sum1[0] * ls1 + sum1[1] * ls2; + sumi2 += sum2[0] * ls1 + sum2[1] * ls2; + qs += 4; + qh += 2; + } + + sumf += GGML_FP16_TO_FP32(scale.f16) * y[i].d * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + + for (; ib < nb; ++ib) { + const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + sumi1 += y[ib].qs[j+ 0] * kvalues_iq4nl[x[ib].qs[j] & 0xf]; + sumi2 += y[ib].qs[j+QK4_NL/2] * kvalues_iq4nl[x[ib].qs[j] >> 4]; + } + sumf += d * (sumi1 + sumi2); + } + *s = sumf; +} + +void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int ibl = 0; ibl < nb; ++ibl) { + const float d4d8 = GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + uint16_t h = x[ibl].scales_h; + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + for (int ib = 0; ib < QK_K/32; ib += 2) { + const uint8_t ls1 = (x[ibl].scales_l[ib/2] & 0xf) | ((h << 4) & 0x30); + const uint8_t ls2 = (x[ibl].scales_l[ib/2] >> 4) | ((h << 2) & 0x30); + h >>= 4; + const float d1 = d4d8*(ls1 - 32); + const float d2 = d4d8*(ls2 - 32); + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d1 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + sumi1 = sumi2 = 0; + for (int j = 0; j < 16; ++j) { + sumi1 += q8[j+ 0] * kvalues_iq4nl[qs[j] & 0xf]; + sumi2 += q8[j+16] * kvalues_iq4nl[qs[j] >> 4]; + } + sumf += d2 * (sumi1 + sumi2); + qs += 16; + q8 += 32; + } + } + *s = sumf; +} + +// ============================ 4-bit non-linear quants + +void quantize_row_iq4_nl(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + assert(k % QK4_NL == 0); + quantize_row_iq4_nl_ref(x, y, k); +} + +void quantize_row_iq4_xs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + quantize_iq4_xs(x, y, 1, k, NULL); +} diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.h b/ggml/src/ggml-cpu/quants.h similarity index 56% rename from ggml/src/ggml-cpu/ggml-cpu-quants.h rename to ggml/src/ggml-cpu/quants.h index e33d9d473..dc4342c87 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -58,6 +58,32 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +// Generic implementation +void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q5_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q6_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq2_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_xxs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq3_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_s_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq1_m_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_nl_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_iq4_xs_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp new file mode 100644 index 000000000..5c6715d5c --- /dev/null +++ b/ggml/src/ggml-cpu/repack.cpp @@ -0,0 +1,1555 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" +#include "ggml-backend-impl.h" + +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" +#include "traits.h" + +#include "arch-fallback.h" + +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT + +#include "repack.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#endif + +#define UNUSED GGML_UNUSED + +static inline int nearest_int(float fval) { + assert(fabsf(fval) <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +// Functions to create the interleaved data layout formats + +// interleave 4 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x4 +// in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks +// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave +// +// - in : an array of block_q4_0 pointers +// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of +// blck_size_interleave bytes +// - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes +// from bias offset form to pure sign form (this saves subtract +// operations durin unpacking) +// + +extern "C" { + +void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +} + +void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +} + +void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + + // scalar + const int blck_size_interleave = 8; + float srcv[4][QK_K]; + float iscale[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + float max = 0; + + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } + } + + iscale[row_iter] = amax ? -127.f/max : 0; + + y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; + } + + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + + // Quants values are interleaved in sequence of eight bytes from corresponding super blocks + // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving + // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on + for (int j = 0; j < QK_K * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3); + + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; + } + } +} + +} // extern "C" + +template +void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row); + +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row); +} + +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); +} + +template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); +} + +extern "C" { + +void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32; + uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32; + uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +} // extern "C" + +static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 2 / blck_size_interleave; + + if (blck_size_interleave == 8) { + const uint64_t xor_mask = 0x8888888888888888ULL; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + // Using memcpy to avoid unaligned memory accesses + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + } else if (blck_size_interleave == 4) { + const uint32_t xor_mask = 0x88888888; + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint32_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +// interleave 8 block_q4_0s in blocks of blck_size_interleave +// returns an interleaved block_q4_0x8 +// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks +// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave +static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x8 out; + + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 4 / blck_size_interleave; + const uint64_t xor_mask = 0x8888888888888888ULL; + + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + return out; +} + +static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { + block_q4_Kx8 out; + //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 4 / blck_size_interleave; + + // Interleave Q4_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K + // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q4_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + + } + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + + } + + return out; +} + +static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + constexpr int nrows_interleaved = 4; + + block_q4_0x4 * dst = (block_q4_0x4 *)t->data; + const block_q4_0 * src = (const block_q4_0 *)data; + block_q4_0 dst_tmp[4]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} +static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; + const block_q4_K * src = (const block_q4_K*) data; + block_q4_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q4_0x8 * dst = (block_q4_0x8*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx4 out; + + for (int i = 0; i < 4; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * 2 / blck_size_interleave; + + // TODO: this branch seems wrong + //if (blck_size_interleave == 8) { + // for (int i = 0; i < end; ++i) { + // int src_id = i % 4; + // int src_offset = (i / 4) * blck_size_interleave; + // int dst_offset = i * blck_size_interleave; + + // // Using memcpy to avoid unaligned memory accesses + // memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + // } + //} else + if (blck_size_interleave == 4) { + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + //GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + GGML_ASSERT(interleave_block == 4); + + block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data; + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nl dst_tmp[4]; + int nrow = ggml_nrows(t); + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +namespace ggml::cpu::repack { +// repack +template +int repack(struct ggml_tensor *, const void *, size_t); + +// TODO: generalise. +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); +} + +// TODO: needs to be revisited +//template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { +// return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size); +//} + +// gemv +template +void gemv(int, float *, size_t, const void *, const void *, int, int); + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +// gemm +template +void gemm(int, float *, size_t, const void *, const void *, int, int); + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; +}; + +template class tensor_traits : public tensor_traits_base { + + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + // not realy a GGML_TYPE_Q8_0 but same size. + switch (op->op) { + case GGML_OP_MUL_MAT: + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + return true; + case GGML_OP_MUL_MAT_ID: + size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2]; + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + forward_mul_mat(params, op); + return true; + case GGML_OP_MUL_MAT_ID: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne1 == ne11); + GGML_ASSERT(ne2 == ne12); + GGML_ASSERT(ne3 == ne13); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_n_dims(op->src[0]) == 2); + // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2); + + char * wdata = static_cast(params->wdata); + const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); + + assert(params->wsize >= nbw1 * ne11); + + const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; + + int64_t i11_processed = 0; + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + ggml_quantize_mat_t((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10); + } + + i11_processed = ne11 - ne11 % 4; + for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { + from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10); + } + + ggml_barrier(params->threadpool); + + const void * src1_wdata = params->wdata; + const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10); + int64_t src0_start = (ith * ne01) / nth; + int64_t src0_end = ((ith + 1) * ne01) / nth; + src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start; + src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end; + if (src0_start >= src0_end) { + return; + } + + // If there are more than three rows in src1, use gemm; otherwise, use gemv. + if (ne11 > 3) { + gemm(ne00, + (float *) ((char *) dst->data) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + } + for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { + gemv(ne00, + (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, + (const char *) src0->data + src0_start * nb01, + (const char *) src1_wdata + (src1_col_stride * iter), 1, + src0_end - src0_start); + } + } + + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne3 == 1); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); + const size_t nbw2 = nbw1*ne11; + const size_t nbw3 = nbw2*ne12; + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + GGML_ASSERT(params->wsize >= (GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) + + n_as * ne12 * sizeof(mmid_row_mapping))); + + auto * wdata = (char *) params->wdata; + auto * wdata_src1_end = (char *) wdata + GGML_PAD(nbw3, sizeof(int64_t)); + auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] + + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] + + // src1: float32 => param type + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { + from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11), + (void *) (wdata + i12 * nbw2 + i11 * nbw1), + ne10); + } + } + +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)] + + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } + } + } + + ggml_barrier(params->threadpool); + + // compute each matrix multiplication in sequence + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const auto * src0_cur = (const char *) src0->data + cur_a*nb02; + + //const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows + + int64_t src0_cur_start = (ith * ne01) / nth; + int64_t src0_cur_end = ((ith + 1) * ne01) / nth; + + src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + if (src0_cur_start >= src0_cur_end) { + return; + } + + for (int ir1 = 0; ir1 < nr1; ir1++) { + struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); + + gemv(ne00, + (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, + src1_col, 1, src0_cur_end - src0_cur_start); + } + } +#undef MMID_MATRIX_ROW + } + + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), + (int) NB_COLS, (int) INTER_SIZE); + return ggml::cpu::repack::repack(t, data, data_size); + } +}; + +// instance for Q4 +static const tensor_traits q4_0_4x4_q8_0; +static const tensor_traits q4_0_4x8_q8_0; +static const tensor_traits q4_0_8x8_q8_0; +static const tensor_traits q4_K_8x8_q8_K; + +// instance for IQ4 +static const tensor_traits iq4_nl_4x4_q8_0; + +} // namespace ggml::cpu::repack + +static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) { + if (cur->type == GGML_TYPE_Q4_0) { + if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { + if (cur->ne[1] % 8 == 0) { + return &ggml::cpu::repack::q4_0_8x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 4 == 0) { + return &ggml::cpu::repack::q4_0_4x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &ggml::cpu::repack::q4_0_4x4_q8_0; + } + } + } else if (cur->type == GGML_TYPE_Q4_K) { + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 8 == 0) { + return &ggml::cpu::repack::q4_K_8x8_q8_K; + } + } + } else if (cur->type == GGML_TYPE_IQ4_NL) { + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &ggml::cpu::repack::iq4_nl_4x4_q8_0; + } + } + } + + return nullptr; +} + +static enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { + tensor->extra = (void *) const_cast(ggml_repack_get_optimal_repack_type(tensor)); + + GGML_UNUSED(buffer); + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_cpu_repack_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, + const void * data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::repack::tensor_traits_base *) tensor->extra; + auto OK = tensor_traits->repack(tensor, data, size); + + GGML_ASSERT(OK == 0); + GGML_UNUSED(buffer); +} + +static const char * ggml_backend_cpu_repack_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_REPACK"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_repack_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_cpu_repack_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_cpu_repack_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + GGML_UNUSED(buft); +} + +namespace ggml::cpu::repack { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + if ( op->op == GGML_OP_MUL_MAT && + op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() && + ggml_repack_get_optimal_repack_type(op->src[0]) + ) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + //if (op->src[1]->type == GGML_TYPE_Q8_0) { + // return true; + //} + // may be possible if Q8_0 packed... + } else if (op->op == GGML_OP_MUL_MAT_ID + && op->src[0]->buffer + && (ggml_n_dims(op->src[0]) == 3) + && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() + && ggml_repack_get_optimal_repack_type(op->src[0]) + ) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + //if (op->src[1]->type == GGML_TYPE_Q8_0) { + // return true; + //} + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) { + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + } + return nullptr; + } +}; +} // namespace ggml::cpu::repack + +ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_repack = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_repack_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_repack_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .is_host = */ nullptr, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::repack::extra_buffer_type(), + }; + + return &ggml_backend_cpu_buffer_type_repack; +} diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h new file mode 100644 index 000000000..4421e5f8e --- /dev/null +++ b/ggml/src/ggml-cpu/repack.h @@ -0,0 +1,98 @@ +#pragma once + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + +#include "traits.h" +#include "ggml.h" + +// GGML internal header + +ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void); + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + int8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +// control size +static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); +static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); +static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); +static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); + +using block_q4_0x4 = block<4, 4>; +using block_q4_0x8 = block<4, 8>; +using block_q8_0x4 = block<8, 4>; +using block_q8_0x8 = block<8, 8>; + +struct block_q4_Kx8 { + ggml_half d[8]; // super-block scale for quantized scales + ggml_half dmin[8]; // super-block scale for quantized mins + uint8_t scales[96]; // scales and mins, quantized with 6 bits + uint8_t qs[1024]; // 4--bit quants +}; + +static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); + +struct block_q8_Kx4 { + float d[4]; // delta + int8_t qs[QK_K * 4]; // quants + int16_t bsums[QK_K / 4]; // sum of quants in groups of 16 +}; + +static_assert(sizeof(block_q8_Kx4) == sizeof(float) * 4 + QK_K * 4 + (QK_K / 4) * sizeof(int16_t), "wrong q8_K block size/padding"); + +struct block_iq4_nlx4 { + ggml_half d[4]; // deltas for 4 iq4_nl blocks + uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks +}; + +static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding"); + +#if defined(__cplusplus) +extern "C" { +#endif + +void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + +// Native implementations +void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); + +#if defined(__cplusplus) +} // extern "C" +#endif diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 04d10cec2..e42364c59 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -17,7 +17,123 @@ // number of elements to fit in a single register // -#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) +#if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_FMA) + +#define GGML_SIMD + +// F32 SVE +#define GGML_F32_EPR 8 +#define DEFAULT_PG svptrue_b32() + +#define GGML_F32xt svfloat32_t +#define GGML_F32xt_ZERO svdup_n_f32(0.0f) +#define GGML_F32xt_SET1(x) svdup_n_f32(x) +#define GGML_F32xt_LOAD_IMPL(pg, a, ...) svld1_f32(pg, a) +#define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) +#define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) +#define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_MUL_IMPL(pg, a, b) svmul_f32_m(pg, a, b) +#define GGML_F32xt_MUL(...) GGML_F32xt_MUL_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_ONE_IMPL(pg, a) svaddv(pg, a) +#define GGML_F32xt_REDUCE_ONE(...) GGML_F32xt_REDUCE_ONE_IMPL(DEFAULT_PG, __VA_ARGS__) +#define GGML_F32xt_REDUCE_IMPL(pg, res, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8) \ +{ \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum2); \ + sum3 = svadd_f32_m(DEFAULT_PG, sum3, sum4); \ + sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum6); \ + sum7 = svadd_f32_m(DEFAULT_PG, sum7, sum8); \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum3); \ + sum5 = svadd_f32_m(DEFAULT_PG, sum5, sum7); \ + sum1 = svadd_f32_m(DEFAULT_PG, sum1, sum5); \ + (res) = (ggml_float) GGML_F32xt_REDUCE_ONE(sum1); \ +} +#define GGML_F32xt_REDUCE(...) GGML_F32xt_REDUCE_IMPL(DEFAULT_PG, __VA_ARGS__) + +#define GGML_F32_VEC GGML_F32xt +#define GGML_F32_VEC_ZERO GGML_F32xt_ZERO +#define GGML_F32_VEC_SET1 GGML_F32xt_SET1 +#define GGML_F32_VEC_LOAD GGML_F32xt_LOAD +#define GGML_F32_VEC_STORE GGML_F32xt_STORE +#define GGML_F32_VEC_FMA GGML_F32xt_FMA +#define GGML_F32_VEC_ADD GGML_F32xt_ADD +#define GGML_F32_VEC_MUL GGML_F32xt_MUL +#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE + +// F16 NEON + +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + #define GGML_F16_STEP 32 + #define GGML_F16_EPR 8 + + #define GGML_F16x8 float16x8_t + #define GGML_F16x8_ZERO vdupq_n_f16(0.0f) + #define GGML_F16x8_SET1(x) vdupq_n_f16(x) + #define GGML_F16x8_LOAD(x) vld1q_f16((const __fp16 *)(x)) + #define GGML_F16x8_STORE vst1q_f16 + #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) + #define GGML_F16x8_ADD vaddq_f16 + #define GGML_F16x8_MUL vmulq_f16 + #define GGML_F16x8_REDUCE(res, x) \ + do { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f16((x)[i], (x)[offset+i]); \ + } \ + const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ + const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ + (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ + } while (0) + + #define GGML_F16_VEC GGML_F16x8 + #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO + #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), (r)[i]) + #define GGML_F16_VEC_FMA GGML_F16x8_FMA + #define GGML_F16_VEC_ADD GGML_F16x8_ADD + #define GGML_F16_VEC_MUL GGML_F16x8_MUL + #define GGML_F16_VEC_REDUCE GGML_F16x8_REDUCE +#else + // if FP16 vector arithmetic is not supported, we use FP32 instead + // and take advantage of the vcvt_ functions to convert to/from FP16 + + #define GGML_F16_STEP 16 + #define GGML_F16_EPR 4 + + #define GGML_F32Cx4 float32x4_t + #define GGML_F32Cx4_ZERO vdupq_n_f32(0.0f) + #define GGML_F32Cx4_SET1(x) vdupq_n_f32(x) + #define GGML_F32Cx4_LOAD(x) vcvt_f32_f16(vld1_f16((const __fp16 *)(x))) + #define GGML_F32Cx4_STORE(x, y) vst1_f16(x, vcvt_f16_f32(y)) + #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) + #define GGML_F32Cx4_ADD vaddq_f32 + #define GGML_F32Cx4_MUL vmulq_f32 + #define GGML_F32Cx4_REDUCE GGML_F32x4_REDUCE + + #define GGML_F16_VEC GGML_F32Cx4 + #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO + #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 + #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) + #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA + #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD + #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL + #define GGML_F16_VEC_REDUCE GGML_F32Cx4_REDUCE +#endif + +#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) #define GGML_SIMD @@ -341,7 +457,7 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { #define GGML_F32_EPR 4 #define GGML_F32x4 vector float -#define GGML_F32x4_ZERO 0.0f +#define GGML_F32x4_ZERO {0.0f} #define GGML_F32x4_SET1 vec_splats #define GGML_F32x4_LOAD(p) vec_xl(0, p) #define GGML_F32x4_STORE(p, r) vec_xst(r, 0, p) @@ -828,10 +944,8 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { for (int i = 0; i < offset; ++i) { \ x[i] = vec_add(x[i], x[offset + i]); \ } \ - res = vec_extract(x[0], 0) + \ - vec_extract(x[0], 1) + \ - vec_extract(x[0], 2) + \ - vec_extract(x[0], 3); \ + float32x4_t tmp = x[0] + vec_reve(x[0]); \ + res = tmp[0] + tmp[1]; \ } #define GGML_F32_VEC GGML_F32x4 diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.cpp b/ggml/src/ggml-cpu/traits.cpp similarity index 97% rename from ggml/src/ggml-cpu/ggml-cpu-traits.cpp rename to ggml/src/ggml-cpu/traits.cpp index 62a0712da..139fa5964 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +++ b/ggml/src/ggml-cpu/traits.cpp @@ -1,4 +1,4 @@ -#include "ggml-cpu-traits.h" +#include "traits.h" #include "ggml-backend-impl.h" #include "ggml-backend.h" diff --git a/ggml/src/ggml-cpu/ggml-cpu-traits.h b/ggml/src/ggml-cpu/traits.h similarity index 100% rename from ggml/src/ggml-cpu/ggml-cpu-traits.h rename to ggml/src/ggml-cpu/traits.h diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index dfe2218e3..f7614568e 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -2,12 +2,6 @@ #include -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) -#endif - // precomputed gelu table for f16 (128 KB) ggml_fp16_t ggml_table_gelu_f16[1 << 16]; @@ -23,29 +17,98 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G #if defined(GGML_SIMD) float sumf = 0.0f; - const int np = (n & ~(GGML_F32_STEP - 1)); - GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t sum1 = svdup_n_f32(0.0f); + svfloat32_t sum2 = svdup_n_f32(0.0f); + svfloat32_t sum3 = svdup_n_f32(0.0f); + svfloat32_t sum4 = svdup_n_f32(0.0f); + svfloat32_t sum5 = svdup_n_f32(0.0f); + svfloat32_t sum6 = svdup_n_f32(0.0f); + svfloat32_t sum7 = svdup_n_f32(0.0f); + svfloat32_t sum8 = svdup_n_f32(0.0f); + svfloat32_t ax1,ax2,ax3,ax4,ax5,ax6,ax7,ax8; + svfloat32_t ay1,ay2,ay3,ay4,ay5,ay6,ay7,ay8; + for (int i = 0; i < np; i += ggml_f32_step) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); - sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); + sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + + ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); + sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + + ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); + sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + + ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); + sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + + ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); + sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + + ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); + sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); } - } + // leftovers + // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop + const int np2 = (n & ~(ggml_f32_epr - 1)); + for (int i = np; i < np2; i += ggml_f32_epr) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + } + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np2 < n) { + svbool_t pg = svwhilelt_b32(np2, n); + ax1 = svld1_f32(pg, x + np2); + ay1 = svld1_f32(pg, y + np2); + sum1 = svmad_f32_m(pg, ax1, ay1, sum1); + } + // reduce sum1,sum2 to sum1 + GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8); + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // reduce sum0..sum3 to sum0 - GGML_F32_VEC_REDUCE(sumf, sum); + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; + } + #endif #else // scalar ggml_float sumf = 0.0; diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 23cbb3051..09dbade21 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -5,6 +5,7 @@ #include "ggml-impl.h" #include "simd-mappings.h" #include "ggml.h" +#include "ggml-cpu.h" #if defined(GGML_USE_ACCELERATE) #include @@ -148,27 +149,108 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) { #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 8 * ggml_f32_epr; // choose 8 SVE registers + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); - GGML_F32_VEC ax[GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; + svfloat32_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + for (int i = 0; i < np; i += ggml_f32_step) { - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + GGML_F32_VEC_STORE(y + i, ay1); + + ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + + GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); + + ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); + ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + + GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); + + ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); + ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + + GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); + + ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); + ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + + GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); + + ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); + ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + + GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); + + ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); + ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + + GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); + + ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); + ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + + GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } - } + // leftovers + // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop + const int np2 = (n & ~(ggml_f32_epr - 1)); + for (int i = np; i < np2; i += ggml_f32_epr) { + ax1 = GGML_F32_VEC_LOAD(x + i); + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); - // leftovers - for (int i = np; i < n; ++i) { - y[i] += x[i]*v; - } + GGML_F32_VEC_STORE(y + i, ay1); + } + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np2 < n) { + svbool_t pg =svwhilelt_b32(np2, n); + ax1 = svld1_f32(pg, x + np2); + ay1 = svld1_f32(pg, y + np2); + ay1 = svmad_f32_m(pg, ax1, vx, ay1); + + svst1_f32(pg, y + np2, ay1); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] += x[i]*v; + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -220,36 +302,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int } #if defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); - - GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - vx[k] = GGML_F32_VEC_SET1(v[k][0]); - } - - GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + #if defined(__ARM_FEATURE_SVE) + // scalar Route to scalar implementation //TODO: Write SVE code + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; } - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } - } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // leftovers - for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { - for (int i = np; i < n; ++i) { - y[i] += x[k][i]*v[k][0]; + GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = GGML_F32_VEC_SET1(v[k][0]); } - } + + GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } + #endif #else // scalar for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { @@ -265,25 +356,53 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(GGML_USE_ACCELERATE) vDSP_vsmul(y, 1, &v, y, 1, n); #elif defined(GGML_SIMD) - const int np = (n & ~(GGML_F32_STEP - 1)); + #if defined(__ARM_FEATURE_SVE) + const int sve_register_length = ggml_cpu_get_sve_cnt() * 8; + const int ggml_f32_epr = sve_register_length / 32;//8;//svcntw(); // SVE128:4, SVE256:8, SVE512:16 + const int ggml_f32_step = 2 * ggml_f32_epr; - GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + const int np = (n & ~(ggml_f32_step - 1)); + svfloat32_t ay1; + svfloat32_t ay2; + for (int i = 0; i < np; i += ggml_f32_step) { + ay1 = GGML_F32_VEC_LOAD(y + i); + ay1 = GGML_F32_VEC_MUL(ay1, vx); + GGML_F32_VEC_STORE(y + i, ay1); - GGML_F32_VEC ay[GGML_F32_ARR]; - - for (int i = 0; i < np; i += GGML_F32_STEP) { - for (int j = 0; j < GGML_F32_ARR; j++) { - ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_MUL(ay[j], vx); - - GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); + ay2 = GGML_F32_VEC_MUL(ay2, vx); + GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); } - } + // leftovers + // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only + if (np < n) { + svbool_t pg = svwhilelt_b32(np, n); + ay1 = svld1_f32(pg, y + np); + ay1 = svmul_f32_m(pg, ay1, vx); + svst1_f32(pg, y + np, ay1); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); - // leftovers - for (int i = np; i < n; ++i) { - y[i] *= v; - } + GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_MUL(ay[j], vx); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] *= v; + } + #endif #else // scalar for (int i = 0; i < n; ++i) { @@ -428,6 +547,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +static const float SQRT_2_INV = 0.70710678118654752440084436210484f; inline static float ggml_gelu_f32(float x) { return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); @@ -440,6 +560,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp } } +inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + for (int i = 0; i < n; ++i) { + float xi = GGML_FP16_TO_FP32(x[i]); + float res = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV)); + y[i] = GGML_FP32_TO_FP16(res); + } +} + #ifdef GGML_GELU_FP16 inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { uint16_t t; @@ -463,6 +591,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { } #endif +inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) { + for (int i = 0; i < n; ++i) { + float xi = x[i]; + y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV)); + } +} + inline static float ggml_gelu_quick_f32(float x) { return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); } @@ -512,6 +647,42 @@ inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) { #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461" #endif +/* Below function was borrowed from the GitHub repository: +https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp */ +#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__) + inline static svfloat32_t exp_ps_sve(svbool_t pg, svfloat32_t src) { + // Constants + const svfloat32_t log2_e = svdup_n_f32(1.4426950409f); + const svfloat32_t ln2 = svdup_n_f32(0.6931473921f); + const svfloat32_t half_ln2_sq = svdup_n_f32(0.2413862043f); + const svuint32_t not_mask17 = svdup_n_u32(~((1u << 17) - 1)); + const svfloat32_t one = svdup_n_f32(1.0f); + const svfloat32_t inactive1 = svdup_n_f32(0.0f); + const svint32_t inactive2 = svdup_n_s32(0); + + // Algorithm starts here + svfloat32_t t0 = svmul_f32_m(pg, src, log2_e); // y = x * log2(e) + svfloat32_t t1 = svrintm_f32_m(inactive1, pg, t0); // rount to int (float) + svint32_t t2 = svcvt_s32_f32_m(inactive2, pg, t1); // n + + t1 = svsub_f32_m(pg, t0, t1); // a = y - floor(y) + t1 = svadd_f32_m(pg, t1, one); // b = a + 1 + + svuint32_t t3 = svlsr_n_u32_m(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32) + svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v) + t4 = svscale_f32_m(pg, t4, t2); // fexpa(v) * 2^(n) + + // and_(t2.d, t1.d, not_mask17.d) + svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_m(pg, svreinterpret_u32_f32(t1), not_mask17)); + t5 = svsub_f32_m(pg, t1, t5); // z + t0 = svmla_f32_m(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z + t0 = svmla_f32_m(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z) + t0 = svmul_f32_m(pg, t0, t4); // Final result + + return t0; + } +#endif + #if defined(__ARM_NEON) && defined(__aarch64__) // adapted from arm limited optimized routine diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 8623214c7..c9ff4aa32 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -12,12 +12,30 @@ if (CUDAToolkit_FOUND) # 61 == Pascal, __dp4a instruction (per-byte integer dot product) # 70 == V100, FP16 tensor cores # 75 == Turing, int8 tensor cores + # 80 == Ampere, asynchronous data loading, faster tensor core instructions + # 86 == RTX 3000, needs CUDA v11.1 + # 89 == RTX 4000, needs CUDA v11.8 + # + # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run + # XX-real == compile CUDA code as device code for this specific architecture + # no suffix == compile as both PTX and device code + # + # The default behavior for a non-native is to build virtual architectures as needed to cover all features needed + # for best performance and to also build real architectures for the most commonly used GPUs. if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") set(CMAKE_CUDA_ARCHITECTURES "native") elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) - set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80") + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") + set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real") + else() + set(CMAKE_CUDA_ARCHITECTURES "60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real") + endif() else() - set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80") + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") + set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real;89-real") + else() + set(CMAKE_CUDA_ARCHITECTURES "50-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-real") + endif() endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") @@ -100,7 +118,7 @@ if (CUDAToolkit_FOUND) set(CUDA_CXX_FLAGS "") - set(CUDA_FLAGS -use_fast_math) + set(CUDA_FLAGS -use_fast_math -extended-lambda) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") # Options are: @@ -133,6 +151,7 @@ if (CUDAToolkit_FOUND) COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion" OUTPUT_VARIABLE CUDA_CCVER ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE ) else() if (CUDA_CCFULLVER MATCHES Apple) @@ -143,7 +162,7 @@ if (CUDAToolkit_FOUND) string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER}) endif() - message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") + message(STATUS "CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER}) list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later diff --git a/ggml/src/ggml-cuda/acc.cu b/ggml/src/ggml-cuda/acc.cu index 96bfe1c9d..e084607c0 100644 --- a/ggml/src/ggml-cuda/acc.cu +++ b/ggml/src/ggml-cuda/acc.cu @@ -1,47 +1,61 @@ #include "acc.cuh" -static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset) { - const int i = blockDim.x * blockIdx.x + threadIdx.x; +static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) { + const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + if (i >= ne) { return; } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; + + int64_t src1_idx = i - offset; + + int64_t tmp = src1_idx; + const int64_t i13 = tmp / s13; + tmp -= i13 * s13; + const int64_t i12 = tmp / s12; + tmp -= i12 * s12; + const int64_t i11 = tmp / s11; + tmp -= i11 * s11; + const int64_t i10 = tmp; + + float val = x[i]; + if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) { + val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10]; } + dst[i] = val; } -static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, const int offset, cudaStream_t stream) { - int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; - acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset); +static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) { + const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; + acc_f32<<>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); } void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - float * dst_d = (float *)dst->data; + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(dst->nb[0] == ggml_element_size(dst)); + GGML_ASSERT(ggml_is_contiguously_allocated(dst)); - acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream); + const int64_t s1 = dst->op_params[0] / sizeof(float); + const int64_t s2 = dst->op_params[1] / sizeof(float); + const int64_t s3 = dst->op_params[2] / sizeof(float); + const int64_t offset = dst->op_params[3] / sizeof(float); + + acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream); } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8284a0017..c14a12f54 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -78,13 +78,13 @@ // Moore Threads #define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210) -#define GGML_CUDA_CC_QY1 (GGML_MUSA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 -#define GGML_CUDA_CC_QY2 (GGML_MUSA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 -#define GGML_CUDA_CC_NG (GGML_MUSA_CC_OFFSET_MTHREADS + 0x310) // TBD +#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 +#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 +#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) -#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NEXT) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) #define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) #ifdef __CUDA_ARCH_LIST__ @@ -130,10 +130,6 @@ static int ggml_cuda_highest_compiled_arch(const int arch) { #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - #define GGML_CUDA_MAX_STREAMS 8 [[noreturn]] @@ -172,7 +168,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) -#if !defined(GGML_USE_HIP) +#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; cuGetErrorString(err, &err_str); @@ -211,9 +207,9 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) +#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) #define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4)) +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE @@ -266,11 +262,11 @@ static bool cp_async_available(const int cc) { } static constexpr __device__ int ggml_cuda_get_physical_warp_size() { -#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) - return __AMDGCN_WAVEFRONT_SIZE; +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__)) + return 64; #else return 32; -#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__)) } [[noreturn]] @@ -300,6 +296,25 @@ static __device__ void no_device_code( #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") #endif // __CUDA_ARCH__ +// The compiler is always able to unroll loops if they contain continue expressions. +// In such cases loop unrolling can still be achieved via recursion: +template +struct ggml_cuda_unroll { + template + __device__ void operator()(const Func & f, Args... args) const { + f(n - 1, args...); + ggml_cuda_unroll{}(f, args...); + } +}; + +template <> +struct ggml_cuda_unroll<1> { + template + __device__ void operator()(const Func & f, Args... args) const { + f(0, args...); + } +}; + template static __device__ __forceinline__ int warp_reduce_sum(int x) { #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE @@ -451,9 +466,6 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } -// TODO: move to ggml-common.h -static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); static __device__ __forceinline__ float get_alibi_slope( @@ -620,6 +632,7 @@ struct ggml_cuda_device_info { int nsm; // number of streaming multiprocessors size_t smpb; // max. shared memory per block size_t smpbo; // max. shared memory per block (with opt-in) + bool integrated; // Device is integrated as opposed to discrete bool vmm; // virtual memory support size_t vmm_granularity; // granularity of virtual memory size_t total_vram; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index a224ec0e1..c6dec4276 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -1,6 +1,8 @@ #include "convert.cuh" #include "dequantize.cuh" +#include + #define CUDA_Q8_0_NE_ALIGN 2048 template @@ -570,30 +572,46 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t } template -static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { - const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void convert_unary( + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; - if (i >= k) { + if (i00 >= ne00) { return; } + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z % ne02; + const int64_t i03 = blockIdx.z / ne02; + const src_t * x = (const src_t *) vx; - y[i] = float(x[i]); + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; + y[iy] = float(x[ix]); } template -static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - convert_unary<<>>(vx, y, k); +static void convert_unary_cuda(const void * vx, dst_t * y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + convert_unary<<>> + (vx, y, ne00, ne01, ne02, s01, s02, s03); +} + +template +static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + convert_unary_cuda(vx, y, k, 1, 1, 1, k, k, k, stream); } to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: - return convert_unary_cuda; + return convert_unary_cont_cuda; case GGML_TYPE_F16: - return convert_unary_cuda; + return convert_unary_cont_cuda; default: return nullptr; } @@ -643,9 +661,9 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F32: - return convert_unary_cuda; + return convert_unary_cont_cuda; case GGML_TYPE_BF16: - return convert_unary_cuda; + return convert_unary_cont_cuda; default: return nullptr; } @@ -692,7 +710,18 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { case GGML_TYPE_IQ3_S: return dequantize_row_iq3_s_cuda; case GGML_TYPE_F16: - return convert_unary_cuda; + return convert_unary_cont_cuda; + case GGML_TYPE_BF16: + return convert_unary_cont_cuda; + default: + return nullptr; + } +} + +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 411a13cf1..b65b98e08 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -3,7 +3,7 @@ #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 template -using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream); +using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream); typedef to_t_cuda_t to_fp32_cuda_t; typedef to_t_cuda_t to_fp16_cuda_t; @@ -14,3 +14,13 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type); to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type); to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); + +// TODO more general support for non-contiguous inputs + +template +using to_t_nc_cuda_t = void (*)(const void * x, T * y, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream); + +typedef to_t_nc_cuda_t to_fp16_nc_cuda_t; +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type); diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index ecb659997..63d0c482f 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -2,6 +2,17 @@ #include "common.cuh" + +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +#ifdef CP_ASYNC_AVAILABLE + return __cvta_generic_to_shared(generic_ptr); +#else + GGML_UNUSED(generic_ptr); + NO_DEVICE_CODE; + return 0; +#endif // CP_ASYNC_AVAILABLE +} + // Copies data from global to shared memory, cg == cache global. // Both the src and dst pointers must be aligned to 16 bit. // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 4f4faa3e6..2c55d2149 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,5 +1,8 @@ #include "cpy.cuh" #include "dequantize.cuh" +#ifdef GGML_USE_MUSA +#include "ggml-musa/mudnn.cuh" +#endif // GGML_USE_MUSA typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -551,7 +554,7 @@ static void ggml_cpy_f16_f16_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -588,14 +591,23 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char ** dest_ptrs_d = nullptr; int graph_cpynode_index = -1; #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection) { + if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; } +#else + GGML_UNUSED(disable_indirection_for_this_node); #endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); +#ifdef GGML_USE_MUSA + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) { + CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0)); + } else +#endif // GGML_USE_MUSA + { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { @@ -636,16 +648,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_type_name(src0->type), ggml_type_name(src1->type)); } #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) - if(ctx.cuda_graph->use_cpy_indirection) { + if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; } +#else + GGML_UNUSED(disable_indirection_for_this_node); #endif } void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - ggml_cuda_cpy(ctx, src0, dst); + bool disable_indirection = true; + ggml_cuda_cpy(ctx, src0, dst, disable_indirection); } void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh index 6bed0564d..0bd3c0c6f 100644 --- a/ggml/src/ggml-cuda/cpy.cuh +++ b/ggml/src/ggml-cuda/cpy.cuh @@ -2,7 +2,7 @@ #define CUDA_CPY_BLOCK_SIZE 64 -void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false); void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 56121705b..cfab2b5eb 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) { nullptr; } -template // D == head size +template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) { @@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; - if (tid < 2*parallel_blocks) { - ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid]; + for (int i = tid; i < 2*parallel_blocks; i += D) { + ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i]; } __syncthreads(); @@ -665,23 +665,27 @@ static void on_no_fattn_vec_case(const int D) { fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n"); GGML_ABORT("fatal error"); } else { - fprintf(stderr, "Unsupported KV type combination for head_size 256.\n"); + fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D); fprintf(stderr, "Only f16 is supported.\n"); GGML_ABORT("fatal error"); } } -template +template void launch_fattn( ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE ) { constexpr int ncols = ncols1 * ncols2; + const bool is_mla = DV == 512; // TODO better parameterization + const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + GGML_ASSERT(V || is_mla); + const ggml_tensor * mask = dst->src[3]; ggml_tensor * KQV = dst; @@ -689,9 +693,13 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); + GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT( K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && - "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); @@ -713,12 +721,13 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = (const char *) V->data; - size_t nb21 = V->nb[1]; - size_t nb22 = V->nb[2]; - size_t nb23 = V->nb[3]; + const char * V_data = V ? (const char *) V->data : nullptr; + size_t nb21 = V ? V->nb[1] : nb11; + size_t nb22 = V ? V->nb[2] : nb12; + size_t nb23 = V ? V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { + GGML_ASSERT(ggml_is_contiguously_allocated(K)); K_f16.alloc(ggml_nelements(K)); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); @@ -732,7 +741,8 @@ void launch_fattn( nb13 = nb13*bs*sizeof(half)/ts; } - if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V && need_f16_V && V->type != GGML_TYPE_F16) { + GGML_ASSERT(ggml_is_contiguously_allocated(V)); V_f16.alloc(ggml_nelements(V)); to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); @@ -752,10 +762,13 @@ void launch_fattn( const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + dim3 blocks_num; if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. - const int max_blocks = 2*nsm; + const int max_blocks = max_blocks_per_sm*nsm; const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); @@ -767,14 +780,11 @@ void launch_fattn( blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float)); + dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. - CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); @@ -851,19 +861,19 @@ void launch_fattn( if (stream_k) { if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup <<>> ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]); } } else if (parallel_blocks > 1) { - const dim3 block_dim_combine(D, 1, 1); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results + flash_attn_combine_results <<>> (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 04804a15c..e230f6d49 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -13,104 +13,386 @@ typedef tile<16, 16, float> tile_C_KQ_16; typedef tile<16, 4, half2> tile_C_VKQ; typedef tile<16, 8, half2> tile_C_VKQ_16; -template -static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( - const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) { - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. +// Config options for specific head sizes. +// Should not affect results, only speed/register pressure/shared memory use. +// +// nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. +// nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory). +// Q_in_reg: whether the Q values should be kept permanently in registers. +// nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading. +// nbatch_K2: number of K half2 values in direction of DKQ to load in parallel. +// nbatch_V2: number of V half2 values in direction of DV to load in parallel. +// nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel. - // If cp.async is available, load up to the highest power of 2 in D asynchronously: -#ifdef CP_ASYNC_AVAILABLE - static_assert(D >= 64 && D < 512, "bad D"); - constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128); +template +struct fattn_mma_f16_config; - const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV); +template <> +struct fattn_mma_f16_config< 64, 64> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; - constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - constexpr int chunks_per_row = k0_sync_start / h2_per_chunk; - constexpr int stride_i = WARP_SIZE / chunks_per_row; -#pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row); - const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk; - - cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k); + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 32; } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 32; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 32; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 32; + } +}; + +template <> +struct fattn_mma_f16_config< 80, 80> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 40; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 40; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 40; + } +}; + +template <> +struct fattn_mma_f16_config< 96, 96> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 48; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 48; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 48; + } +}; + +template <> +struct fattn_mma_f16_config<112, 112> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 56; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 56; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 56; + } +}; + +template <> +struct fattn_mma_f16_config<128, 128> { + static constexpr int nbatch_fa = 64; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 64; + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 64; + } +}; + +template <> +struct fattn_mma_f16_config<256, 256> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 4; + static constexpr bool Q_in_reg = true; + static constexpr int nstages_target = 2; + + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { + return 128; + } + + static int get_nbatch_combine_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 128 : 64; + } + return 64; + } + + static constexpr __device__ int get_nbatch_combine_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 128 : 64; #else - constexpr int k0_sync_start = 0; -#endif // CP_ASYNC_AVAILABLE - static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start"); + GGML_UNUSED(ncols); + return 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } +}; + +template <> +struct fattn_mma_f16_config<576, 512> { + static constexpr int nbatch_fa = 32; + static constexpr int nwarps_max = 8; + static constexpr bool Q_in_reg = false; + static constexpr int nstages_target = 1; + + static int get_nbatch_K2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 96 : 160; + } + return ncols <= 16 ? 288 : 160; + } + + static constexpr __device__ int get_nbatch_K2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 96 : 160; +#else + return ncols <= 16 ? 288 : 160; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_V2_host(const int cc, const int ncols) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) { + return ncols <= 16 ? 64 : 128; + } + return ncols <= 16 ? 256 : 128; + } + + static constexpr __device__ int get_nbatch_V2_device(int ncols) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + return ncols <= 16 ? 64 : 128; +#else + return ncols <= 16 ? 256 : 128; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + } + + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { + return 128; + } + + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { + return 128; + } +}; + +// ------------------------------------------------------------------------------------------------------------------ + +template +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) { - // If D is not a power of 2, the rest is loaded synchronously. // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds"); -#pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. - if (k0_start == k0_stop || k0_stop <= k0_sync_start) { - continue; - } + if (use_cp_async) { + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; -#pragma unroll - for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); -#pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + auto load = [&] __device__ (auto n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; - tile_KV[i*D2_padded + k] = KV[i*stride_KV + k]; + if (k0_start == k0_stop) { + return; } - } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } + }; + ggml_cuda_unroll<5>{}(load); + } else { + static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds"); + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*stride_tile + k] = KV[i*stride_KV + k]; + } + } + }; + ggml_cuda_unroll<3>{}(load); } } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) { - static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter"); -#ifdef CP_ASYNC_AVAILABLE - constexpr int preload = KQ_per_iter * sizeof(half); - constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter; - constexpr int stride_j = nwarps * cols_per_warp; + static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter"); - const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask); + if (use_cp_async) { + constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; + constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); +#pragma unroll + for (int j0 = 0; j0 < ncols1; j0 += stride_j) { + const int j = j0 + threadIdx.y*cols_per_warp + + (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp)); + + if (j0 + stride_j > ncols1 && j >= ncols1) { + break; + } + + const int i = 4 * (threadIdx.x % (nbatch_fa/8)); + + cp_async_cg_16(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + } + return; + } + + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + - (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8)); + const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp)); if (j0 + stride_j > ncols1 && j >= ncols1) { break; } - const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8)); + const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp); - cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i); + tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i]; } -#else - constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter; - constexpr int stride_j = nwarps * cols_per_warp; -#pragma unroll - for (int j0 = 0; j0 < ncols1; j0 += stride_j) { - const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2)); - - if (j0 + stride_j > ncols1 && j >= ncols1) { - break; - } - - const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2); - - tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i]; - } -#endif // CP_ASYNC_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -123,9 +405,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float logit_softcap, const int ne01, const int ne02, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, + half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, half2 * const __restrict__ tile_mask, @@ -135,59 +419,113 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float * const __restrict__ KQ_rowsum, const int kb0) { #ifdef NEW_MMA_AVAILABLE + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); - const int k_VKQ_0 = kb0 * KQ_per_iter; - tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles]; + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + + const int k_VKQ_0 = kb0 * c::nbatch_fa; + tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles]; // Use wide variants of tiles if ntiles >= 2. tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C; -#ifdef CP_ASYNC_AVAILABLE - cp_async_wait_all(); - __syncthreads(); - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); -#else - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); - } - flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - - // Calculate tile of KQ: -#pragma unroll - for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) { - const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) { - tile_A K_A; - load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded); - if (ntiles == 1) { - mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); - } - } + if constexpr (nstages > 1) { + static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); + } else { + constexpr bool use_cp_async = nstages == 1; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask); } } -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. -#endif // CP_ASYNC_AVAILABLE +#pragma unroll + for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; + const int k0_diff = k0_stop - k0_start; + + if (nstages <= 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate tile of KQ: + if constexpr (c::Q_in_reg) { +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + if (ntiles == 1) { + mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A); + } + } + } + } + } else { + static_assert(ntiles == 2, "ntiles != 2 not implemented"); +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) { + load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I; + + tile_A K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + + // Wide version of KQ_C is column-major => swap A and B. + mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A); + } + } + } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } if (use_logit_softcap) { - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) { + for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); @@ -205,7 +543,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I; #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { @@ -213,16 +551,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2; KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope * - __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]); + __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]); } } } // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]); @@ -238,10 +576,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); - + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) { #pragma unroll for (int l = 0; l < tile_C_KQ::ne; ++l) { KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]); @@ -252,7 +589,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } else { // ntiles > 1 if (ncols2 > 1 || mask_h2) { #pragma unroll - for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) { + for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) { const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J; #pragma unroll for (int t = 0; t < ntiles/2; ++t) { @@ -261,7 +598,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2; const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2; - const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]); + const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]); const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t; KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x; KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y; @@ -272,9 +609,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Calculate softmax for each KQ column using the current max. value. // The divisor is stored in KQ_rowsum and will be applied at the end. - static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -294,9 +631,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } - static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size"); #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { #pragma unroll @@ -315,9 +652,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]); + const float KQ_max_diff = KQ_max[col] - KQ_max_new[col]; + KQ_max_scale[col] = expf(KQ_max_diff); KQ_max[col] = KQ_max_new[col]; + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; } @@ -325,7 +665,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (ntiles == 1) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ::I; ++i) { + for (int i = 0; i < DV/tile_C_VKQ::I; ++i) { #pragma unroll for (int l = 0; l < tile_C_VKQ::ne; ++l) { VKQ_C[i].x[l] *= KQ_max_scale_h2; @@ -336,7 +676,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); #pragma unroll - for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) { + for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) { #pragma unroll for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) { VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2; @@ -347,16 +687,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } // Convert KQ C tiles into B tiles for VKQ calculation: - tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles]; + tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles]; tile_B_16 * B_16 = (tile_B_16 *) B; - static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size"); + static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size"); if (ntiles == 1) { #pragma unroll - for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) { B[k] = get_transposed(get_half2(KQ_C[k])); } } else { - for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) { + for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]); @@ -364,62 +704,83 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } -#ifdef CP_ASYNC_AVAILABLE - // Preload K tile for next iteration: - cp_async_wait_all(); - __syncthreads(); - if (!last_iter) { - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask); + if (nstages > 1) { + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV); } -#else - flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV); - __syncthreads(); -#endif // CP_ASYNC_AVAILABLE - // Calculate VKQ tile: -#pragma unroll - for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) { - static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size"); -#pragma unroll - for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) { - const int k0 = k00 + (threadIdx.y % np)*tile_A::J; - tile_A A; - load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded); - if (ntiles == 1) { - mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); - } else { + // For MLA K and V have the same data. + // Therefore, iterate over V in reverse and re-use the data if possible. + static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); + constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; #pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { + const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; + const int i0_diff = i0_stop - i0_start; + + if (nstages <= 1 && i0_start < reusable_cutoff) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile + (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; + + // Calculate VKQ tile: +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) { + static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) { + const int k0 = k00 + (threadIdx.y % np)*tile_A::J; + + tile_A A; + load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + if (ntiles == 1) { + mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]); + } else { +#pragma unroll + for (int t = 0; t < ntiles/2; ++t) { + // Wide version of VKQ_C is column-major => swap A and B. + mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A); + } } } } + + if (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } } - -#ifndef CP_ASYNC_AVAILABLE - __syncthreads(); // Only needed if tile_K == tile_V. -#endif // CP_ASYNC_AVAILABLE - #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV); + GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); - GGML_UNUSED(kb0); + GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -434,7 +795,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int ne02, const int stride_Q1, const int stride_Q2, - const int stride_KV, + const int stride_K, + const int stride_V, const int stride_mask, const int jt, const int kb0_start, @@ -442,29 +804,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #ifdef NEW_MMA_AVAILABLE //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + typedef fattn_mma_f16_config c; + +#ifdef CP_ASYNC_AVAILABLE + constexpr int nstages = c::nstages_target; +#else + constexpr int nstages = 0; +#endif // CP_ASYNC_AVAILABLE + constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = ntiles * tile_B::I; constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles; constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols); + constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols); static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); - static_assert(D % nwarps == 0, "bad D"); - static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter"); + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; - constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts. + static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); + constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; - // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements: - extern __shared__ half2 tile_K[]; -#ifdef CP_ASYNC_AVAILABLE - half2 * tile_V = tile_K + KQ_per_iter*D2_padded; -#else - half2 * tile_V = tile_K; -#endif // CP_ASYNC_AVAILABLE - half2 * tile_mask = tile_V + KQ_per_iter*D2_padded; + extern __shared__ half2 tile_Q[]; + half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K; + half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max; - tile_B Q_B[D/(2*tile_B::J) * ntiles]; - tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles]; + tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles]; + tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles]; tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B; tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C; @@ -476,13 +846,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( KQ_max[col] = -FLT_MAX/2.0f; } - // Temporarily load Q data into tile_K, will be loaded into registers afterwards. + // Load Q data into tile_Q, either temporarily or permanently. + // Q in registers is faster, but register pressure is the biggest bottleneck. // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); + const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); const int stride_jc = WARP_SIZE / stride_k; if (k0_start == k0_stop) { @@ -506,14 +877,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; - tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y); + tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); } } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f); + tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } } } @@ -521,18 +892,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - { + if (c::Q_in_reg) { const int j0 = (threadIdx.y / np) * cols_per_warp; #pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { + for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) { if (ntiles == 1) { - load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded); + load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); } else { #pragma unroll for (int t = 0; t < ntiles/2; ++t) { load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t], - tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded); + tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q); } } } @@ -540,35 +911,37 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( __syncthreads(); - // Preload mask and K data for first iteration when using cp_async: -#ifdef CP_ASYNC_AVAILABLE - if (ncols2 > 1 || mask_h2) { - flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask); + // Preload mask and K data for first iteration when using cp_async with multiple stages: + if constexpr (nstages > 1) { + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + constexpr bool use_cp_async = true; + if (ncols2 > 1 || mask_h2) { + flash_attn_ext_f16_load_mask + (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); + } + flash_attn_ext_f16_load_tile + (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } - flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV); -#endif // CP_ASYNC_AVAILABLE // Iterate over ne11 == previous tokens: for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) { constexpr bool last_iter = false; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; - flash_attn_ext_f16_iter + flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } - // With cp_async there is no __syncthreads at the end of the iter, + // With multi-stage loading there is no __syncthreads at the end of the iter, // there can be a race condition on shared memory access for combining/writing back results. -#ifdef CP_ASYNC_AVAILABLE - if (nwarps*cols_per_warp > KQ_per_iter) { + if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) { __syncthreads(); } -#endif // CP_ASYNC_AVAILABLE // Finally, sum up partial KQ rowsums. // The partial sums are spread across 8/4 threads each, does not need full reduce. @@ -584,38 +957,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } - // Write VKQ accumulators to shared memory in column-major format. - // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. - // Also for np > 1 the combination is done via these values in shared memory. - if (ntiles == 1) { - const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_B::J) { - const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. + // Combine VKQ accumulator values if np > 1. + // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // So also write VKQ accumulators to shared memory in column-major format if np == 1. -#pragma unroll - for (int l = 0; l < tile_B::ne; ++l) { - const int k = k0 + tile_B::get_j(l); - - tile_K[jc_cwd*D2_padded + k] = B.x[l]; - } - } - } else { -#pragma unroll - for (int t = 0; t < ntiles/2; ++t) { - const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; -#pragma unroll - for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) { -#pragma unroll - for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { - const int j = j0 + tile_C_VKQ_16::get_i(l); - const int k = k0 + tile_C_VKQ_16::get_j(l); - - tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; - } - } - } - } + constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols); + constexpr int tile_stride = nbatch_combine + 4; + static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); if constexpr (ntiles == 1) { const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset @@ -624,7 +972,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -649,7 +997,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) { // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. - ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr; + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; } __syncthreads(); @@ -676,11 +1024,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); - float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4; + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2]; + meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. @@ -690,10 +1038,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); } float KQ_cms[nmeta]; // KQ combine max scale per warp. @@ -709,18 +1056,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset >= WARP_SIZE) { - continue; + if (offset < WARP_SIZE) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); } + __syncthreads(); + // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. - meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } @@ -734,90 +1082,125 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - } - - if (np > 1) { + } else if (np > 1) { + // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch. + // Therefore, all other warps also need to execute a __syncthreads(). + // Otherwise the points at which warps synchronize with each other would become misaligned. __syncthreads(); } - if (np == 1 || threadIdx.y % np == 0) { - // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. - // The values after that are for the partial results of the individual blocks. - float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2)); +#pragma unroll + for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { + if (ntiles == 1) { + const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data +#pragma unroll + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) { + const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format. #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k); - const int k0_stop = D/2 - (D/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + for (int l = 0; l < tile_B::ne; ++l) { + const int k = k0 + tile_B::get_j(l); - if (k0_start == k0_stop) { - continue; + tile_Q[jc_cwd*tile_stride + k] = B.x[l]; + } } - + } else { #pragma unroll - for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); - - if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { - break; - } - - const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; - - const int j_dst = jc_dst / ncols2; - const int c_dst = jc_dst % ncols2; - - if (!is_fixup && jt*ncols1 + j_dst >= ne01) { - continue; - } - - const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2; + for (int t = 0; t < ntiles/2; ++t) { + const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I; #pragma unroll - for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); - - float2 dstk_val = make_float2(0.0f, 0.0f); + for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) { #pragma unroll - for (int ip = 0; ip < np; ++ip) { - const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0]; - const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]); - dstk_val.x += dstk_val_add.x*KQ_crs; - dstk_val.y += dstk_val_add.y*KQ_crs; - } + for (int l = 0; l < tile_C_VKQ_16::ne; ++l) { + const int j = j0 + tile_C_VKQ_16::get_i(l); + const int k = k0 + tile_C_VKQ_16::get_j(l); - if (!needs_fixup && !is_fixup) { - const float KQ_rowsum_j = meta_j[1]; - dstk_val.x /= KQ_rowsum_j; - dstk_val.y /= KQ_rowsum_j; - } - - if (is_fixup) { - dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val; - } else { - dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val; + tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l]; } } } } - } - if (np > 1) { __syncthreads(); + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && jt*ncols1 + j_dst >= ne01) { + continue; + } + + const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; + const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; + } + } + } + } + } + if (np > 1) { + __syncthreads(); + } } #else GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1); - GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask); + GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop); NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } -template -__launch_bounds__(nwarps*WARP_SIZE, 2) +template +__launch_bounds__(nwarps*WARP_SIZE, 1) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -857,24 +1240,36 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(D == 128 || D == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { NO_DEVICE_CODE; return; } +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + if (ncols1*ncols2 > 32) { + NO_DEVICE_CODE; + return; + } +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter"); + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); + + typedef fattn_mma_f16_config c; + + static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config::nbatch_fa == 0, "bad nbatch_fa"); const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. const int stride_Q1 = nb01 / sizeof(float2); const int stride_Q2 = nb02 / sizeof(float2); - const int stride_KV = nb11 / sizeof(half2); + const int stride_K = nb11 / sizeof(half2); const int stride_mask = nb31 / sizeof(half2); + const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int iter_k = ne11 / FATTN_KQ_STRIDE; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice. + constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice. // kbc == k block continuous, current index in continuous ijk space. int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x; @@ -893,9 +1288,10 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -905,14 +1301,14 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } else { constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); } kbc += iter_k; @@ -931,9 +1327,10 @@ static __global__ void flash_attn_ext_f16( const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2); const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr; - float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2); + float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2); + + const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; @@ -942,9 +1339,9 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); + ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); @@ -960,28 +1357,44 @@ static __global__ void flash_attn_ext_f16( #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } -template +template void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - constexpr int ncols = ncols1 * ncols2; - constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32; - constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4; - constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4); - constexpr int cols_per_warp = ntiles * tile_B::I; + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; - static_assert(D % tile_B::J == 0, "bad D"); + typedef fattn_mma_f16_config c; + + const int nstages = cp_async_available(cc) ? c::nstages_target : 0; + + constexpr int ncols = ncols1 * ncols2; + constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp. + constexpr int cols_per_warp = ntiles * tile_B::I; + constexpr int nwarps_max_x = ncols / cols_per_warp; + constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I; + constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max; + + constexpr bool mla = DKQ == 576; + + const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols); + const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols); + + static_assert(DKQ % tile_B::J == 0, "bad DKQ"); + static_assert(DV % tile_A::J == 0, "bad DV"); static_assert(ncols % cols_per_warp == 0, "bad ncols"); - const ggml_tensor * KQV = dst; - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; + const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); - const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter; + const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; - const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half); - const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half); - const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half); - - const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine); + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ? + std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : + nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); @@ -989,59 +1402,73 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; + +#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) } - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true); } -#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \ - template void ggml_cuda_flash_attn_ext_mma_f16_case \ - (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ +#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ -#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \ - extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \ +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \ -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64) -DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) -// Kernels with ncols == 128 are only 4% faster due to register pressure. -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128) -// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory. +// The number of viable configurations for Deepseek is very limited: +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index e0039e175..9283560d5 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; case 128: { @@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index fcb6f848f..32673adb5 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; case 128: { @@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * constexpr int nwarps = 8; constexpr size_t nbytes_shared = 0; fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn + launch_fattn (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); } break; default: { diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index e17d2d0e4..35e649cb3 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -2,9 +2,9 @@ #include "fattn-common.cuh" template // D == head size -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#ifndef GGML_USE_HIP __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // GGML_USE_HIP static __global__ void flash_attn_vec_ext_f16( const char * __restrict__ Q, const char * __restrict__ K, @@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16( NO_DEVICE_CODE; return; } +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + if (ncols > 1) { + NO_DEVICE_CODE; + return; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. @@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16( kqsum_shared[j][threadIdx.x] = 0.0f; } } + + __shared__ half maskh_shared[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*D + tid] = 0.0f; + } + __syncthreads(); // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers: @@ -168,12 +181,43 @@ static __global__ void flash_attn_vec_ext_f16( for (int j = 0; j < ncols; ++j) { KQ[j*D + tid] = -HALF_MAX_HALF; } + __syncthreads(); half2 VKQ[ncols] = {{0.0f, 0.0f}}; for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: + if (mask) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid]; + } + + __syncthreads(); + + // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out. + // In such cases, skip the KV slice. + // On AMD __all_sync would not work correctly because it assumes a warp size of 64. +#ifndef GGML_USE_HIP + bool skip = true; +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]); + skip = skip && isinf(tmp.x) && isinf(tmp.y); + } + } + if (__all_sync(0xFFFFFFFF, skip)) { + __syncthreads(); + continue; + } +#endif // GGML_USE_HIP + } + // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression, // see https://github.com/ggerganov/llama.cpp/pull/7061 . // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable). @@ -201,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f16( sum = logit_softcap*tanhf(sum); } - sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); + sum += maskh_shared[j*D + i_KQ]; if (ncols == 1) { kqmax_new = ggml_cuda_hmax(kqmax_new, sum); @@ -315,7 +359,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template @@ -334,7 +378,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - if (Q->ne[1] == 1) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index d42ddca49..953967917 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -2,9 +2,9 @@ #include "fattn-common.cuh" template // D == head size -#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#ifndef GGML_USE_HIP __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) +#endif // GGML_USE_HIP static __global__ void flash_attn_vec_ext_f32( const char * __restrict__ Q, const char * __restrict__ K, @@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32( NO_DEVICE_CODE; return; } +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + if (ncols > 1) { + NO_DEVICE_CODE; + return; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. @@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32( kqsum_shared[j][threadIdx.x] = 0.0f; } } + + __shared__ float maskf_shared[ncols*D]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskf_shared[j*D + tid] = 0.0f; + } + __syncthreads(); // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: @@ -181,6 +194,35 @@ static __global__ void flash_attn_vec_ext_f32( for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: + if (mask) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]); + } + + __syncthreads(); + + // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out. + // In such cases, skip the KV slice. + // On AMD __all_sync would not work correctly because it assumes a warp size of 64. +#ifndef GGML_USE_HIP + bool skip = true; +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + skip = skip && isinf(maskf_shared[j*D + i]); + } + } + if (__all_sync(0xFFFFFFFF, skip)) { + __syncthreads(); + continue; + } +#endif // GGML_USE_HIP + } + float kqmax_new_arr[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -204,7 +246,7 @@ static __global__ void flash_attn_vec_ext_f32( sum = logit_softcap*tanhf(sum); } - sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + sum += maskf_shared[j*D + i_KQ]; kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum); @@ -310,7 +352,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, constexpr bool need_f16_K = D != 128; constexpr bool need_f16_V = D != 128 && D != 64; constexpr size_t nbytes_shared = 0; - launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); + launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); } template @@ -326,7 +368,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - if (Q->ne[1] == 1) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) { constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index bc21b27a0..c5668adb1 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm fattn_kernel = flash_attn_ext_f16< D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; } - launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); + launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); } void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 7a2d1e453..6bc0096cc 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -8,58 +8,33 @@ #include "fattn-wmma-f16.cuh" #include "fattn.cuh" -template +template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * Q = dst->src[0]; - if (Q->ne[1] <= 8/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; + if constexpr (ncols2 <= 8) { + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } } if (Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } - if (Q->ne[1] <= 32/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); } -template -static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } -} - -static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -68,27 +43,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg float max_bias = 0.0f; memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - const float use_gqa_opt = mask && max_bias == 0.0f; + const bool use_gqa_opt = mask && max_bias == 0.0f; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; if (use_gqa_opt && gqa_ratio % 8 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 4) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio == 2) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + switch (Q->ne[0]) { + case 64: + GGML_ASSERT(V->ne[0] == 64); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst); + break; + case 80: + GGML_ASSERT(V->ne[0] == 80); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst); + break; + case 96: + GGML_ASSERT(V->ne[0] == 96); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst); + break; + case 112: + GGML_ASSERT(V->ne[0] == 112); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst); + break; + case 128: + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); + break; + case 256: + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + break; + case 576: { + // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + GGML_ASSERT(V->ne[0] == 512); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 16 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } break; + default: + GGML_ABORT("fatal error"); + break; + } } #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ @@ -299,7 +326,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion; - const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0; if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { if (prec == GGML_PREC_DEFAULT) { ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 4cef53a98..963e4d03d 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -10,10 +10,11 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; - const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; - const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; if (i00 >= ne00) { return; @@ -33,8 +34,8 @@ static __global__ void k_get_rows( dfloat2 v; dequantize_kernel(src0_row, ib, iqs, v); - dst_row[iybs + iqs + 0] = v.x; - dst_row[iybs + iqs + y_offset] = v.y; + dst_row[iybs + iqs + 0] = float(v.x); + dst_row[iybs + iqs + y_offset] = float(v.y); } template @@ -46,10 +47,11 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - const int i00 = blockIdx.x*blockDim.x + threadIdx.x; - const int i10 = blockDim.y*blockIdx.y + threadIdx.y; - const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12; - const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12; + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i00 = blockIdx.y * blockDim.x + threadIdx.x; + const int i10 = blockIdx.x; + const int i11 = blockIdx.z / ne12; + const int i12 = blockIdx.z % ne12; if (i00 >= ne00) { return; @@ -60,7 +62,7 @@ static __global__ void k_get_rows_float( dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); - dst_row[i00] = src0_row[i00]; + dst_row[i00] = float(src0_row[i00]); } template @@ -86,120 +88,159 @@ static __global__ void k_get_rows_back_float( dst[dst_row*ncols + col] = sum; } -template -static void get_rows_cuda( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - +template +static void get_rows_cuda_q( + const void * src0_d, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); - const dim3 block_nums(block_num_x, ne10, ne11*ne12); + const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); + const dim3 block_nums(ne10, block_num_y, ne11*ne12); // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / ggml_element_size(dst); + // const size_t s0 = nb0 / sizeof(dst_t); + const size_t s1 = nb1 / sizeof(dst_t); + const size_t s2 = nb2 / sizeof(dst_t); + const size_t s3 = nb3 / sizeof(dst_t); - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); + const size_t s10 = nb10 / sizeof(int32_t); + const size_t s11 = nb11 / sizeof(int32_t); + const size_t s12 = nb12 / sizeof(int32_t); + // const size_t s13 = nb13 / sizeof(int32_t); GGML_ASSERT(ne00 % 2 == 0); k_get_rows<<>>( - src0_dd, src1_dd, dst_dd, + src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); - - GGML_UNUSED(dst); } -template +template static void get_rows_cuda_float( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - GGML_ASSERT(ne13 == 1); - + const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); - const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; - const dim3 block_nums(block_num_x, ne10, ne11*ne12); + const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; + const dim3 block_nums(ne10, block_num_y, ne11*ne12); // strides in elements - //const size_t s0 = nb0 / ggml_element_size(dst); - const size_t s1 = nb1 / ggml_element_size(dst); - const size_t s2 = nb2 / ggml_element_size(dst); - const size_t s3 = nb3 / ggml_element_size(dst); + // const size_t s0 = nb0 / sizeof(dst_t); + const size_t s1 = nb1 / sizeof(dst_t); + const size_t s2 = nb2 / sizeof(dst_t); + const size_t s3 = nb3 / sizeof(dst_t); - const size_t s10 = nb10 / ggml_element_size(src1); - const size_t s11 = nb11 / ggml_element_size(src1); - const size_t s12 = nb12 / ggml_element_size(src1); - //const size_t s13 = nb13 / ggml_element_size(src1); + const size_t s10 = nb10 / sizeof(int32_t); + const size_t s11 = nb11 / sizeof(int32_t); + const size_t s12 = nb12 / sizeof(int32_t); + // const size_t s13 = nb13 / sizeof(int32_t); k_get_rows_float<<>>( - src0_dd, src1_dd, dst_dd, + src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10, ne11,*/ ne12, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); +} - GGML_UNUSED(dst); +template +static void ggml_cuda_get_rows_switch_src0_type( + const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + switch (src0_type) { + case GGML_TYPE_F16: + get_rows_cuda_float((const half *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_F32: + get_rows_cuda_float((const float *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_BF16: + get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q4_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q4_1: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q5_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q5_1: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q8_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + default: + // TODO: k-quants + GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type)); + break; + } +} + +void get_rows_cuda( + const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type, + int64_t ne00, size_t nb01, size_t nb02, size_t nb03, + int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12, + size_t nb1, size_t nb2, size_t nb3, + cudaStream_t stream) { + switch (dst_type) { + case GGML_TYPE_F32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_F16: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_BF16: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + default: + GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type)); + break; + } } void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - const void * src0_d = (const void *) src0->data; - const int32_t * src1_d = (const int32_t *) src1->data; - float * dst_d = (float *) dst->data; - cudaStream_t stream = ctx.stream(); + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(src1->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ne13 == 1); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); - switch (src0->type) { - case GGML_TYPE_F16: - get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream); - break; - case GGML_TYPE_F32: - get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream); - break; - case GGML_TYPE_Q4_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); - break; - case GGML_TYPE_Q4_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); - break; - case GGML_TYPE_Q5_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); - break; - case GGML_TYPE_Q5_1: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); - break; - case GGML_TYPE_Q8_0: - get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream); - break; - default: - // TODO: k-quants - GGML_ABORT("%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type)); - break; - } + get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); } void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/getrows.cuh b/ggml/src/ggml-cuda/getrows.cuh index a1ca643f1..3c5bea5f4 100644 --- a/ggml/src/ggml-cuda/getrows.cuh +++ b/ggml/src/ggml-cuda/getrows.cuh @@ -3,6 +3,13 @@ #define CUDA_GET_ROWS_BLOCK_SIZE 256 #define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256 +void get_rows_cuda( + const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type, + int64_t ne00, size_t nb01, size_t nb02, size_t nb03, + int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12, + size_t nb1, size_t nb2, size_t nb3, + cudaStream_t stream); + void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fafe9633e..898b24341 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -96,31 +96,32 @@ int ggml_cuda_get_device() { static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); -#if defined(GGML_USE_HIP) && defined(GGML_HIP_UMA) - auto res = hipMallocManaged(ptr, size); - if (res == hipSuccess) { - // if error we "need" to know why... - CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); - } - return res; -#else - -#if !defined(GGML_USE_HIP) cudaError_t err; if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) { err = cudaMallocManaged(ptr, size); +#if defined(GGML_USE_HIP) + if (err == hipSuccess) { + CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + } + + // fall back to cudaMalloc if not supported (e.g. on Windows) + if (err == hipErrorNotSupported) { + static bool warned_unsupported = false; + if (!warned_unsupported) { + GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n"); + warned_unsupported = true; + } + + err = cudaMalloc(ptr, size); + } +#endif // defined(GGML_USE_HIP) } else { err = cudaMalloc(ptr, size); } return err; -#else - return cudaMalloc(ptr, size); -#endif // !defined(GGML_USE_HIP) - -#endif } #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) @@ -242,10 +243,10 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - - info.devices[id].nsm = prop.multiProcessorCount; - info.devices[id].smpb = prop.sharedMemPerBlock; - info.devices[id].warp_size = prop.warpSize; + info.devices[id].integrated = prop.integrated; + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].warp_size = prop.warpSize; #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -554,8 +555,8 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) { // initialize padding to 0 to avoid possible NaN values - size_t original_size = ggml_nbytes(tensor); - size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); + const size_t original_size = ggml_nbytes(tensor); + const size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); if (padded_size > original_size) { ggml_cuda_set_device(ctx->device); @@ -614,9 +615,8 @@ static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaDeviceSynchronize()); - CUDA_CHECK(cudaMemset(ctx->dev_ptr, value, buffer->size)); - CUDA_CHECK(cudaDeviceSynchronize()); + CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { @@ -678,6 +678,7 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t if (ggml_is_quantized(tensor->type)) { if (ne0 % MATRIX_ROW_PADDING != 0) { + GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor)); size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } } @@ -799,6 +800,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context; @@ -850,6 +852,7 @@ static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buff // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context; @@ -888,6 +891,7 @@ static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buff // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context; @@ -969,6 +973,7 @@ static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buf static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context; + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); size_t total_size = 0; @@ -1059,6 +1064,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_ GGML_UNUSED(buft); } +static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name; +} + static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1134,7 +1143,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)( static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { - GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer)); const char * src_ptr = (const char *) src->data; char * dst_ptr = (char *) dst; @@ -1409,11 +1417,14 @@ static void ggml_cuda_op_mul_mat( const int64_t ne0 = dst->ne[0]; const int64_t ne1 = dst->ne[1]; + // const int64_t nb10 = src1->nb[0]; + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; + const int64_t nb2 = dst->nb[2]; const int64_t nb3 = dst->nb[3]; - GGML_ASSERT(ggml_backend_buffer_is_cuda(dst->buffer)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src1->buffer)); ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; @@ -1525,6 +1536,8 @@ static void ggml_cuda_op_mul_mat( // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream)); @@ -1544,7 +1557,10 @@ static void ggml_cuda_op_mul_mat( dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); if (src1_on_device && src1_is_contiguous) { - quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); + quantize_src1( + dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10, + nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float), + src1_padded_col_size, ne11, ne12, ne13, stream); CUDA_CHECK(cudaGetLastError()); } } @@ -1639,7 +1655,9 @@ static void ggml_cuda_op_mul_mat( } if (quantize_src1 && !src1_is_contiguous) { - quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); + quantize_src1( + src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10, + src1_padded_col_size, src1_ncols, 1, 1, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1709,15 +1727,15 @@ static __global__ void k_compute_batched_ptrs( size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, int64_t r2, int64_t r3) { - int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x; - int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y; + const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y; if (i13 >= ne13 || i12 >= ne12) { return; } - int64_t i03 = i13 / r3; - int64_t i02 = i12 / r2; + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13; @@ -1728,9 +1746,13 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); + GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft)); GGML_ASSERT(src0->type == GGML_TYPE_F16); + // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst. + // As long as dst is contiguous this does not matter though. + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_TENSOR_BINARY_OP_LOCALS const int64_t ne_dst = ggml_nelements(dst); @@ -1739,21 +1761,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream)); - void * src0_ddq = src0->data; - half * src0_f16 = (half *) src0_ddq; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; + const half * src0_f16 = (const half *) src0->data; + float * dst_ddf = (float *) dst->data; + + const half * src1_f16 = (const half *) src1->data; + const size_t ts_src1 = ggml_type_size(src1->type); + GGML_ASSERT(nb10 == ts_src1); + int64_t s11 = nb11 / ts_src1; + int64_t s12 = nb12 / ts_src1; + int64_t s13 = nb13 / ts_src1; + ggml_cuda_pool_alloc src1_f16_alloc(ctx.pool()); // convert src1 to fp16 - ggml_cuda_pool_alloc src1_f16_alloc(ctx.pool()); if (src1->type != GGML_TYPE_F16) { - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type); const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); GGML_ASSERT(to_fp16_cuda != nullptr); - to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream); + + to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream); + + src1_f16 = src1_f16_alloc.get(); + s11 = ne10; + s12 = ne11*s11; + s13 = ne12*s12; } - half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get(); ggml_cuda_pool_alloc dst_f16(ctx.pool()); char * dst_t; @@ -1813,13 +1845,13 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co int i02 = i12 / r2; CUBLAS_CHECK( - cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, - ne01, ne11, ne10, - alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), - (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), - beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01, - cu_compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F, nb01/sizeof(half), + src1_f16 + i13*s13 + i12*s12, CUDA_R_16F, s11, + beta, ( char *) dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0, + cu_compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } } } @@ -1830,15 +1862,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co CUBLAS_CHECK( cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, - alpha, (const char *) src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA - (const char *) src1_f16, CUDA_R_16F, nb11/nb10, nb12/nb10, // strideB - beta, ( char *) dst_t, cu_data_type, ne01, nb2/nb0, // strideC + alpha, src0_f16, CUDA_R_16F, nb01/nb00, nb02/nb00, // strideA + src1_f16, CUDA_R_16F, s11, s12, // strideB + beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC ne12*ne13, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } else { // use cublasGemmBatchedEx - const int ne23 = ne12*ne13; + const int64_t ne23 = ne12*ne13; ggml_cuda_pool_alloc ptrs_src(ctx.pool(), 2*ne23); ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); @@ -1850,8 +1882,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co ne12, ne13, ne23, nb02, nb03, - src1->type == GGML_TYPE_F16 ? nb12 : nb12/2, - src1->type == GGML_TYPE_F16 ? nb13 : nb13/2, + src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half), + src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half), nbd2, nbd3, r2, r3); CUDA_CHECK(cudaGetLastError()); @@ -1860,8 +1892,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F, nb01/nb00, - (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, nb11/nb10, - beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01, + (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F, s11, + beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, ne23, cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -1877,13 +1909,19 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_mul_mat_vec = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) + // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q. + // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data. + // Therefore, in such cases use cuBLAS. + const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE + && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; + + bool use_mul_mat_vec = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) + bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool any_gpus_with_slow_fp16 = false; @@ -1918,12 +1956,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) - ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) - && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_vec_q) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_q) { + ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst); + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && + !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // general KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_mul_mat_vec) { @@ -1937,196 +1979,147 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } } -struct mmid_row_mapping { - int32_t i1; - int32_t i2; -}; - -static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, - int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, - const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, - int64_t ne11, int64_t ne10, - size_t nb11, size_t nb12) { - int32_t iid1 = blockIdx.x; - int32_t id = blockIdx.y; - - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); - - if (row_id_i != i02) { - return; - } - - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; - - __shared__ int src1_row; - if (threadIdx.x == 0) { - src1_row = atomicAdd(cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - __syncthreads(); - - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); - float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); - - for (int i = threadIdx.x; i < ne10; i += blockDim.x) { - src1_row_contiguous[i] = src1_row_original[i]; - } -} - -static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous, - const mmid_row_mapping * __restrict__ row_mapping, - int64_t ne0, - size_t nb1, size_t nb2) { - int32_t i = blockIdx.x; - - const int32_t i1 = row_mapping[i].i1; - const int32_t i2 = row_mapping[i].i2; - - const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); - float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); - - for (int j = threadIdx.x; j < ne0; j += blockDim.x) { - dst_row_original[j] = dst_row_contiguous[j]; - } -} - static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers"); + GGML_TENSOR_BINARY_OP_LOCALS - GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers"); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ne2 == 1) { + if (ggml_is_quantized(src0->type)) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + } else { + ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst); + } + return; + } + + if (ggml_cuda_should_use_mmq(src0->type, cc, ne12)) { + ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); + return; + } + } cudaStream_t stream = ctx.stream(); - const int64_t n_as = ne02; - const int64_t n_ids = ids->ne[0]; + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); + + const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc)) + || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type; + const ggml_type type_dst_sorted = GGML_TYPE_F32; + const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted); + const size_t ts_dst_sorted = ggml_type_size(type_dst_sorted); + + const int64_t n_expert_used = ids->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; + + std::vector ids_to_sorted_host; + ids_to_sorted_host.reserve(2*ne_get_rows); + std::vector ids_from_sorted_host(ne_get_rows); + + ggml_cuda_pool_alloc ids_buf_dev(ctx.pool(), 2*ne_get_rows); + + std::vector tokens_per_expert(ne02); + + ggml_cuda_pool_alloc src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted); + ggml_cuda_pool_alloc dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted); std::vector ids_host(ggml_nbytes(ids)); - const char * ids_dev = (const char *) ids->data; - CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); CUDA_CHECK(cudaStreamSynchronize(stream)); - ggml_tensor src0_row = *src0; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; - - char * src0_original = (char *) src0->data; - char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; - - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[3] = nb02; - - src1_row.ne[1] = 1; - src1_row.ne[2] = 1; - src1_row.ne[3] = 1; - src1_row.nb[2] = nb11; - src1_row.nb[3] = nb11; - - dst_row.ne[1] = 1; - dst_row.ne[2] = 1; - dst_row.ne[3] = 1; - dst_row.nb[2] = nb1; - dst_row.nb[3] = nb1; - - if (ne12 == 1) { - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; - - const int64_t i1 = id; - const int64_t i2 = i12; - - src0_row.data = src0_original + i02*nb02; - src1_row.data = src1_original + i11*nb11 + i12*nb12; - dst_row.data = dst_original + i1*nb1 + i2*nb2; - - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); - } - } - } else { - ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_cuda_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); - - src1_row.data = src1_contiguous.get(); - dst_row.data = dst_contiguous.get(); - - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); - - if (row_id_i != i02) { - continue; - } - - num_src1_rows++; + for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices + for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens + for (int64_t iex = 0; iex < n_expert_used; ++iex) { + const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]); + assert(expert_to_use >= 0 && expert_to_use < ne02); + if (expert_to_use == i02) { + ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size(); + ids_to_sorted_host.push_back(i12*ne11 + iex % ne11); + tokens_per_expert[i02]++; + break; } } - - if (num_src1_rows == 0) { - continue; - } - - ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); - - { - dim3 block_dims(std::min((unsigned int)ne10, 768u)); - dim3 grid_dims(ids->ne[1], n_ids); - k_copy_src1_to_contiguous<<>>( - src1_original, src1_contiguous.get(), - dev_cur_src1_row.get(), dev_row_mapping.get(), - ids_dev, i02, ids->nb[1], ids->nb[0], - ne11, ne10, - nb11, nb12); - CUDA_CHECK(cudaGetLastError()); - } - - src0_row.data = src0_original + i02*nb02; - - GGML_ASSERT(nb11 == sizeof(float)*ne10); - GGML_ASSERT(nb1 == sizeof(float)*ne0); - - src1_row.ne[1] = num_src1_rows; - src1_row.nb[1] = nb11; - src1_row.nb[2] = num_src1_rows*nb11; - src1_row.nb[3] = num_src1_rows*nb11; - - dst_row.ne[1] = num_src1_rows; - dst_row.nb[1] = nb1; - dst_row.nb[2] = num_src1_rows*nb1; - dst_row.nb[3] = num_src1_rows*nb1; - - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); - - { - dim3 block_dims(std::min((unsigned int)ne0, 768u)); - dim3 grid_dims(num_src1_rows); - k_copy_dst_from_contiguous<<>>( - dst_original, dst_contiguous.get(), - dev_row_mapping.get(), - ne0, - nb1, nb2); - CUDA_CHECK(cudaGetLastError()); - } } } + GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows)); + + ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end()); + + CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + const int32_t * ids_to_sorted = ids_buf_dev.ptr + 0*ne_get_rows; + const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows; + + get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted, + ne10, nb11, nb12, nb13, + ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t), + ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream); + CUDA_CHECK(cudaGetLastError()); + + char * src1_data_cur = (char *) src1_sorted.ptr; + char * dst_data_cur = (char *) dst_sorted.ptr; + for (int64_t i02 = 0; i02 < ne02; ++i02) { + if (tokens_per_expert[i02] == 0) { + continue; + } + + ggml_tensor src0_slice = *src0; + src0_slice.ne[2] = 1; + src0_slice.nb[3] = src0_slice.nb[2]; + src0_slice.op = GGML_OP_VIEW; + src0_slice.view_src = dst->src[0]; // non-const pointer to src0 + src0_slice.data = (char *) src0->data + i02*nb02; + + ggml_tensor src1_slice; + memset(&src1_slice, 0, sizeof(src1_slice)); + src1_slice.buffer = src1->buffer; + src1_slice.type = type_src1_sorted; + src1_slice.ne[0] = ne10; + src1_slice.ne[1] = tokens_per_expert[i02]; + src1_slice.ne[2] = 1; + src1_slice.ne[3] = 1; + src1_slice.nb[0] = ts_src1_sorted; + src1_slice.nb[1] = src1_slice.ne[0] * src1_slice.nb[0]; + src1_slice.nb[2] = src1_slice.ne[1] * src1_slice.nb[1]; + src1_slice.nb[3] = src1_slice.ne[2] * src1_slice.nb[2]; + src1_slice.data = src1_data_cur; + + ggml_tensor dst_slice; + memset(&dst_slice, 0, sizeof(dst_slice)); + dst_slice.buffer = dst->buffer; + dst_slice.type = type_dst_sorted; + dst_slice.ne[0] = ne0; + dst_slice.ne[1] = tokens_per_expert[i02]; + dst_slice.ne[2] = 1; + dst_slice.ne[3] = 1; + dst_slice.nb[0] = ts_dst_sorted; + dst_slice.nb[1] = dst_slice.ne[0] * dst_slice.nb[0]; + dst_slice.nb[2] = dst_slice.ne[1] * dst_slice.nb[1]; + dst_slice.nb[3] = dst_slice.ne[2] * dst_slice.nb[2]; + dst_slice.data = dst_data_cur; + + ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice); + CUDA_CHECK(cudaGetLastError()); + + src1_data_cur += src1_slice.nb[2]; + dst_data_cur += dst_slice.nb[2]; + } + + get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type, + ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, + ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t), + nb1, nb2, nb3, stream); } static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { @@ -2199,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_SILU: ggml_cuda_op_silu(ctx, dst); break; + case GGML_UNARY_OP_GELU_ERF: + ggml_cuda_op_gelu_erf(ctx, dst); + break; case GGML_UNARY_OP_GELU_QUICK: ggml_cuda_op_gelu_quick(ctx, dst); break; @@ -2488,10 +2484,10 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud #endif } - if (node->op == GGML_OP_MUL_MAT_ID) { + if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { use_cuda_graph = false; // This node type is not supported by CUDA graph capture #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif } @@ -2645,6 +2641,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { + // flag used to determine whether it is an integrated_gpu + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. @@ -2663,10 +2661,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (node->src[j] != nullptr) { assert(node->src[j]->buffer); assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || - ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft)); + ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft))); } } -#endif +#else + GGML_UNUSED(integrated); +#endif // NDEBUG bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (!ok) { @@ -2984,6 +2984,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: @@ -2997,9 +2998,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; - // for small weight matrices the active device can end up without any rows, don't use row split in those cases - // this avoids some edge cases (and the performance would not be good anyways) if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) { + if (a->ne[2] > 1 || a->ne[3] > 1) { + return false; + } + // for small weight matrices the active device can end up without any rows, don't use row split in those cases + // this avoids some edge cases (and the performance would not be good anyways) ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context; int64_t row_low; int64_t row_high; @@ -3202,9 +3206,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: { - const size_t ts = ggml_type_size(op->src[0]->type); - const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2]; - return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts; + return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); } case GGML_OP_IM2COL: case GGML_OP_POOL_2D: @@ -3230,8 +3232,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return false; #endif // FLASH_ATTN_AVAILABLE if (op->src[1]->ne[0] != op->src[2]->ne[0]) { - // different head sizes of K and V are not supported yet - return false; + const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; + if (!new_mma_available(cc)) { + return false; + } + const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2]; + return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0; } if (op->src[0]->ne[0] == 192) { return false; @@ -3264,7 +3270,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev; + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated; + return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft))); } static int64_t get_op_batch_size(const ggml_tensor * op) { diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index b36b43d54..2db5b4ab0 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -1,37 +1,10 @@ #include "mmq.cuh" +#include "quantize.cuh" -void ggml_cuda_op_mul_mat_q( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { +#include - const int64_t ne00 = src0->ne[0]; - - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - GGML_ASSERT(ne10 % QK8_1 == 0); - - const int64_t ne0 = dst->ne[0]; - - const int64_t row_diff = row_high - row_low; - const int64_t stride00 = ne00 / ggml_blck_size(src0->type); - - int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - - // the main device has a larger memory buffer to hold the results from all GPUs - // nrows_dst == nrows of the matrix that the kernel writes into - const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - - // The stream-k decomposition is only faster for recent NVIDIA GPUs. - // Also its fixup needs to allocate a temporary buffer in the memory pool. - // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && - ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11; - const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k}; - - switch (src0->type) { +static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + switch (args.type_x) { case GGML_TYPE_Q4_0: mul_mat_q_case(ctx, args, stream); break; @@ -90,10 +63,208 @@ void ggml_cuda_op_mul_mat_q( GGML_ABORT("fatal error"); break; } +} + +void ggml_cuda_mul_mat_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + + const char * src0_d = (const char *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + // If src0 is a temporary compute buffer, clear any potential padding. + if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + const size_t size_data = ggml_nbytes(src0); + const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); + if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); + } + } + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; + + if (!ids) { + const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), nbytes_src1_q8_1); + + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, + ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + CUDA_CHECK(cudaGetLastError()); + } + + const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s13 = ne12*s12; + + const mmq_args args = { + src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d, + ne00, ne01, ne1, s01, ne11, s1, + ne02, ne12, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k}; + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); + return; + } + + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); + + const int64_t n_expert_used = ids->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; + + std::vector ids_host(ggml_nbytes(ids)); + std::vector ids_src1_host; + ids_src1_host.reserve(ne_get_rows); + std::vector ids_dst_host; + ids_dst_host.reserve(ne_get_rows); + std::vector tokens_per_expert_host(ne02); + std::vector expert_bounds_host(ne02 + 1); + ggml_cuda_pool_alloc ids_buf_dev(ctx.pool()); + + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices + for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens + for (int64_t iex = 0; iex < n_expert_used; ++iex) { + const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]); + assert(expert_to_use >= 0 && expert_to_use < ne02); + if (expert_to_use == i02) { + ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11); + ids_dst_host.push_back(i12*ne1 + iex); + tokens_per_expert_host[i02]++; + break; + } + } + } + } + + int32_t cumsum = 0; + for (int64_t i = 0; i < ne02; ++i) { + expert_bounds_host[i] = cumsum; + cumsum += tokens_per_expert_host[i]; + } + expert_bounds_host[ne02] = cumsum; + + std::vector ids_buf_host; + ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size()); + ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end()); + ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end()); + ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end()); + ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device. + CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + const int32_t * ids_src1_dev = ids_buf_dev.ptr; + const int32_t * ids_dst_dev = ids_src1_dev + ids_src1_host.size(); + const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size(); + + const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), nbytes_src1_q8_1); + + const int64_t ne11_flat = ne12*n_expert_used; + const int64_t ne12_flat = 1; + const int64_t ne13_flat = 1; + + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[2] / ts_src1; + quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type, + ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + CUDA_CHECK(cudaGetLastError()); + } + + const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int)); + const int64_t s13 = ne12*s12; + + // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. + const mmq_args args = { + src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d, + ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, + ne02, ne02, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k}; + + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); +} + +void ggml_cuda_op_mul_mat_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + const int64_t stride01 = ne00 / ggml_blck_size(src0->type); + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + + // The stream-k decomposition is only faster for recent NVIDIA GPUs. + // Also its fixup needs to allocate a temporary buffer in the memory pool. + // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. + const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && + ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11; + const mmq_args args = { + src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, + ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, + 1, 1, 0, 0, 0, + 1, 1, 0, 0, 0, + use_stream_k}; + + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); GGML_UNUSED(src1); GGML_UNUSED(dst); GGML_UNUSED(src1_ddf_i); + GGML_UNUSED(src1_padded_row_size); } bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 532358018..80baf459c 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -13,9 +13,10 @@ using namespace ggml_cuda_mma; #define MMQ_ITER_K 256 #define MMQ_NWARPS 8 -typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride); -typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00); -typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max); +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); +typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, + float * __restrict__ dst, const int stride, const int i_max, const int j_max); enum mmq_q8_1_ds_layout { MMQ_Q8_1_DS_LAYOUT_D4, @@ -155,25 +156,27 @@ static constexpr __device__ int get_mmq_y_device() { #define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { - return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : - type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 : - type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 : - type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : - type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : - type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : - type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : - tile_x_sizes{0, 0, 0}; + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; + case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + default: return tile_x_sizes{0, 0, 0}; + } } #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) @@ -189,25 +192,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : - type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : - type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : - type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : - type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : - 0; + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + default: return 0; + } } #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) @@ -229,7 +234,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */ // ------------------------------------------------------------ template static __device__ __forceinline__ void load_tiles_q4_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -285,7 +290,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; @@ -324,7 +329,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q4_1( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -380,7 +385,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; @@ -419,7 +424,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q5_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -491,7 +496,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_q5_1( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -561,7 +566,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_q8_0( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -617,7 +622,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; @@ -647,7 +652,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; @@ -728,7 +733,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); const int * x_qs = (const int *) x; @@ -758,7 +763,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; @@ -835,7 +840,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; const int * x_qs = (const int *) x; @@ -867,7 +872,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #ifdef NEW_MMA_AVAILABLE typedef tile<16, 4, int> tile_A; @@ -951,7 +956,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q2_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1007,7 +1012,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; @@ -1070,7 +1075,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #ifdef NEW_MMA_AVAILABLE typedef tile<16, 4, int> tile_A; @@ -1197,7 +1202,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_q3_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1294,7 +1299,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; @@ -1336,7 +1341,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co } template static __device__ __forceinline__ void load_tiles_q4_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1433,7 +1438,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; @@ -1465,7 +1470,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q5_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1574,7 +1579,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; @@ -1606,7 +1611,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( } template static __device__ __forceinline__ void load_tiles_q6_K( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1689,7 +1694,7 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; @@ -1722,7 +1727,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { #ifdef NEW_MMA_AVAILABLE typedef tile<16, 4, int> tile_A; @@ -1831,7 +1836,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } template static __device__ __forceinline__ void load_tiles_iq4_nl( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1889,7 +1894,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq2_xxs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -1947,7 +1952,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq2_xs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2003,7 +2008,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq2_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2066,7 +2071,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq3_xxs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2122,7 +2127,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq3_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2185,7 +2190,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq1_s( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2241,7 +2246,7 @@ template static __device__ __forceinlin } template static __device__ __forceinline__ void load_tiles_iq4_xs( - const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { #ifdef NEW_MMA_AVAILABLE int * x_qs = (int *) x_tile; @@ -2302,8 +2307,8 @@ template static __device__ __forceinlin template static __device__ __forceinline__ void mmq_write_back_dp4a( - const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - + const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -2320,15 +2325,15 @@ static __device__ __forceinline__ void mmq_write_back_dp4a( continue; } - dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; } } } template static __device__ __forceinline__ void mmq_write_back_mma( - const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) { - + const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); @@ -2358,7 +2363,7 @@ static __device__ __forceinline__ void mmq_write_back_mma( continue; } - dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; + dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; } } } @@ -2514,17 +2519,18 @@ struct mmq_type_traits { }; template -static __device__ void mul_mat_q_process_tile( - const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0, - const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) { +static __device__ __forceinline__ void mul_mat_q_process_tile( + const char * __restrict__ x, const int offset_x, const int * __restrict__ y, + const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; - extern __shared__ char data_mul_mat_q[]; - int * tile_y = (int *) data_mul_mat_q; + extern __shared__ int data_mul_mat_q[]; + int * tile_y = data_mul_mat_q + mmq_x; int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); #ifdef NEW_MMA_AVAILABLE @@ -2539,16 +2545,11 @@ static __device__ void mul_mat_q_process_tile( float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; - const int tile_x_max_i = ne01 - it*mmq_y - 1; - const int tile_y_max_j = ne11 - jt*mmq_x - 1; - - const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); - for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { - load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01); + load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); { - const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; @@ -2564,7 +2565,7 @@ static __device__ void mul_mat_q_process_tile( __syncthreads(); { - const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); + const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; @@ -2581,12 +2582,10 @@ static __device__ void mul_mat_q_process_tile( } if (fixup) { - write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); + write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); } else { - write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j); + write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j); } - - GGML_UNUSED(ne00); GGML_UNUSED(ne10); } @@ -2605,8 +2604,11 @@ template #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( - const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) { + const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -2617,26 +2619,88 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); + const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x + const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + + // Initialize the ids for writing back data with just the index. + // For regular matrix multiplications this is never changed. + // For MoE the correct indices are loaded from ids_dst. + extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { + const int wt = blockIdx.z / nchannels_y; + const int zt = blockIdx.z - wt*nchannels_y; + const int jt = blockIdx.y; + const int it = blockIdx.x; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // __syncthreads(); // There is no previous tile that could cause a race condition. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + constexpr bool fixup = false; mul_mat_q_process_tile - (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, - blockIdx.x, blockIdx.y, 0, ne00/qk); + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - const int64_t blocks_per_ne00 = ne00 / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; - const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x - const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y - // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x; - int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x; + int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; @@ -2645,13 +2709,66 @@ static __global__ void mul_mat_q( int kb0_start = kbc % blocks_per_ne00; int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { - const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile. - const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile. + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + + continue; + } + + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. mul_mat_q_process_tile - (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, - it, jt, kb0_start, kb0_stop); + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); kbc += blocks_per_ne00; kbc -= kbc % blocks_per_ne00; @@ -2664,55 +2781,108 @@ static __global__ void mul_mat_q( return; } - const int jt = kbc / (blocks_per_ne00*nty); - const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // The memory layout for the fixup buffer is always contiguous, therefore reset ids: + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { + const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + + if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile - (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0, - it, jt, kb0_start, kb0_stop); + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } template static __global__ void mul_mat_q_stream_k_fixup( - float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) { - + const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, + const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; - const int64_t blocks_per_ne00 = ne00 / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; - const int ntx = (ne11 + mmq_x - 1) / mmq_x; - const int nty = (ne01 + mmq_y - 1) / mmq_y; + const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; + + const int bidx0 = blockIdx.x; + + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; + const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } bool any_fixup = false; - const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x); - const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x); + // Iterate over previous blocks and sum up partial sums written to fixup buffer. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int64_t bidx = bidx0 - 1; + int64_t kbc_stop = kbc0; + while(true) { + int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; - int64_t kbc_0; - int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq; - - for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) { - kbc_0 = kbc_stop_0; - kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq; - - const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter; - const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter; - - // Skip fixup tile if the MMQ CUDA block never wrote anything to it: - if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) { - continue; - } - - const int jt = kbc_stop / (blocks_per_ne00*nty); - const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; - - // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block: - if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) { + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; continue; } @@ -2729,16 +2899,72 @@ static __global__ void mul_mat_q_stream_k_fixup( sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + break; + } + bidx--; + kbc_stop = kbc; } if (!any_fixup) { return; } - dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y; + int tmp = kbc0; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; - const int i_max = ne01 - blockIdx.x*mmq_y - 1; - const int j_max = ne11 - blockIdx.y*mmq_x - 1; + if (!ids_dst) { + const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = ncols_dst - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + } + } + return; + } + + __shared__ int ids_dst_shared[mmq_x]; + const int col_low = expert_bounds[zt + 0]; + const int col_high = expert_bounds[zt + 1]; + const int col_diff = col_high - col_low; + + for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) { + ids_dst_shared[j] = ids_dst[col_low + j]; + } + __syncthreads(); + + const int offset_dst = it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = col_diff - jt*mmq_x - 1; #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -2756,26 +2982,27 @@ static __global__ void mul_mat_q_stream_k_fixup( continue; } - dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; } } } struct mmq_args { - const char * x; const char * y; float * dst; - int64_t ne00; int64_t ne01; int64_t stride01; - int64_t ne10; int64_t ne11; int64_t stride11; - int64_t ne0; + const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; + int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; + int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; + int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; bool use_stream_k; }; template -static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) { +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); - const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); - const int shmem_y = mmq_x*sizeof(block_q8_1_mmq); - return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); + const size_t nbs_ids = mmq_x*sizeof(int); + const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); } template @@ -2787,86 +3014,114 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); - const int shmem = mmq_get_shmem(mmq_x, mmq_y, cc); + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; - if (!shmem_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); - CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem)); - shmem_limit_raised[id] = true; + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); + CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared)); + shared_memory_limit_raised[id] = true; } #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA) - const int nty = (args.ne01 + mmq_y - 1) / mmq_y; - const int ntx = (args.ne11 + mmq_x - 1) / mmq_x; - const dim3 block_nums_xy_tiling(nty, ntx, 1); + const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; + const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; + const int ntzw = args.nchannels_y * args.nsamples_y; + const dim3 block_nums_xy_tiling(nty, ntx, ntzw); + + GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); + GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); + const int channel_ratio = args.nchannels_y / args.nchannels_x; + const int sample_ratio = args.nsamples_y / args.nsamples_x; if (!args.use_stream_k) { - if (args.ne01 % mmq_y == 0) { + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> - (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - mul_mat_q<<>> - (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); } return; } - const dim3 block_nums_mmq(nsm, 1, 1); + const dim3 block_nums_stream_k(nsm, 1, 1); + const bool fixup_needed = ntx*nty*ntzw % nsm != 0; ggml_cuda_pool & pool = ctx.pool(id); - ggml_cuda_pool_alloc tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y); + ggml_cuda_pool_alloc tmp_fixup(pool); + if (fixup_needed) { + tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); + } - if (args.ne01 % mmq_y == 0) { + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> - (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); - mul_mat_q_stream_k_fixup<<>> - (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - mul_mat_q<<>> - (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0); + mul_mat_q<<>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); - mul_mat_q_stream_k_fixup<<>> - (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x); + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } } template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int smpbo = ggml_cuda_info().devices[id].smpbo; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); - const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y; - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; int mmq_x_best = 0; - int nparts_best = INT_MAX; + int ntiles_x_best = INT_MAX; - for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) { + for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { const int granularity = mmq_get_granularity_host(mmq_x, cc); - if (mmq_x % granularity != 0 || mmq_get_shmem(mmq_x, mmq_y, cc) > smpbo) { + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc) > smpbo) { continue; } - const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x; - const int nwaves_xy_tiling = ntiles_x*block_num_y; - const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling; + const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x; - if (nparts < nparts_best) { - mmq_x_best = mmq_x; - nparts_best = nparts; + if (ntiles_x < ntiles_x_best) { + mmq_x_best = mmq_x; + ntiles_x_best = ntiles_x; } } @@ -2950,6 +3205,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); // ------------------------------------------------------------------------------------------------------------------------- +void ggml_cuda_mul_mat_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + void ggml_cuda_op_mul_mat_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu index b39961cd1..d8c385e23 100644 --- a/ggml/src/ggml-cuda/mmv.cu +++ b/ggml/src/ggml-cuda/mmv.cu @@ -4,18 +4,23 @@ template static __global__ void mul_mat_vec( - const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { - const int64_t row = blockIdx.x; - const int64_t channel = blockIdx.y; - const int64_t sample = blockIdx.z; - const int tid = threadIdx.x; - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int64_t row = blockIdx.x; + const int64_t channel_dst = blockIdx.y; + const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst; + const int64_t sample_dst = blockIdx.z; + const int64_t sample_x = sample_dst / sample_ratio; + const int64_t sample_y = sample_dst; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row; - y += sample *stride_sample_y + channel *stride_channel_y; - dst += sample *stride_sample_dst + channel *stride_channel_dst; + x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + y += sample_y *stride_sample_y + channel_y *stride_channel_y; + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; const float2 * y2 = (const float2 *) y; @@ -31,12 +36,19 @@ static __global__ void mul_mat_vec( float sumf = 0.0f; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { + const float2 * x2 = (const float2 *) x; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = x2[col2]; + const float2 tmpy = y2[col2]; + sumf += tmpx.x*tmpy.x; + sumf += tmpx.y*tmpy.y; + } + } else if constexpr (std::is_same::value) { const half2 * x2 = (const half2 *) x; if (std::is_same::value) { - sumf = 0.0f; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { const float2 tmpx = __half22float2(x2[col2]); const float2 tmpy = y2[col2]; @@ -59,8 +71,6 @@ static __global__ void mul_mat_vec( } } else if constexpr (std::is_same::value) { const int * x2 = (const int *) x; - sumf = 0.0f; - for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { const int tmpx = x2[col2]; const float2 tmpy = y2[col2]; @@ -92,17 +102,17 @@ static __global__ void mul_mat_vec( template static void launch_mul_mat_vec_cuda( - const T * x, const float * y, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(nchannels_y % nchannels_x == 0); - GGML_ASSERT(nsamples_y % nsamples_x == 0); - const int64_t channel_ratio = nchannels_y / nchannels_x; - const int64_t sample_ratio = nsamples_y / nsamples_x; + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; int device; int warp_size; @@ -124,48 +134,48 @@ static void launch_mul_mat_vec_cuda( } const int smem = warp_size*sizeof(float); - const dim3 block_nums(nrows, nchannels_y, nsamples_y); + const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 64: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 96: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 128: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 160: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 192: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 224: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; case 256: { mul_mat_vec<<>> - (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } break; default: { GGML_ABORT("fatal error"); @@ -175,28 +185,28 @@ static void launch_mul_mat_vec_cuda( template static void mul_mat_vec_cuda( - const T * x, const float * y, float * dst, - const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, enum ggml_prec prec, cudaStream_t stream) { - switch (prec) { - case GGML_PREC_DEFAULT: { + if constexpr(std::is_same::value) { + if (prec == GGML_PREC_DEFAULT) { launch_mul_mat_vec_cuda - (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case GGML_PREC_F32: { - launch_mul_mat_vec_cuda - (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; + (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + return; + } } + launch_mul_mat_vec_cuda + (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); } -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_TENSOR_BINARY_OP_LOCALS; @@ -204,21 +214,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); - GGML_ASSERT(ne11 == 1); - GGML_ASSERT(ne12 == ne2); + GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. GGML_ASSERT(ne13 == ne3); - GGML_ASSERT(nb00 == ts_src0); - GGML_ASSERT(nb10 == ts_src1); - GGML_ASSERT(nb0 == ts_dst); + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + GGML_ASSERT( nb0 == ts_dst); const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; - const float * src1_d = (const float *) src1->data; - float * dst_d = (float *) dst->data; + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s1 = dst->nb[1] / ts_dst; const int64_t s02 = src0->nb[2] / ts_src0; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s2 = dst->nb[2] / ts_dst; @@ -226,14 +239,33 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + GGML_ASSERT(ncols_dst == 1); + switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, prec, ctx.stream()); + } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream()); + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream()); + mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, prec, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -262,27 +294,34 @@ void ggml_cuda_op_mul_mat_vec( const int64_t stride_row = ne00; const int64_t nchannels_x = 1; const int64_t nchannels_y = 1; + const int64_t nchannels_dst = 1; const int64_t stride_channel_x = 0; const int64_t stride_channel_y = 0; const int64_t stride_channel_dst = 0; const int64_t nsamples_x = 1; - const int64_t nsamples_y = 1; + const int64_t nsamples_dst = 1; const int64_t stride_sample_x = 0; const int64_t stride_sample_y = 0; const int64_t stride_sample_dst = 0; switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0_dd_i; + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; - mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, - nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh index 78a1cd4a6..756e7e1cc 100644 --- a/ggml/src/ggml-cuda/mmv.cuh +++ b/ggml/src/ggml-cuda/mmv.cuh @@ -3,7 +3,7 @@ // maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available #define MMV_MAX_ROWS 512 -void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_op_mul_mat_vec( ggml_backend_cuda_context & ctx, diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index eef8585a7..dc7adf509 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -1,50 +1,57 @@ #include "mmvq.cuh" +#include "quantize.cuh" #include "vecdotq.cuh" +#include + typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : - type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : - type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : - type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : - type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : - type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : - type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : - type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : - type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : - type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : - type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : - type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : - type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : - type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : - type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : - type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : - type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : - type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : - type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : - nullptr; + switch (type) { + case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; + case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; + default: return nullptr; + } } static constexpr __device__ int get_vdr_mmvq(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : - type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : - type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : - type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : - type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : - type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : - type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : - type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : - type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : - type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_XS ? VDR_IQ2_XS_Q8_1_MMVQ : - type == GGML_TYPE_IQ2_S ? VDR_IQ2_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ : - type == GGML_TYPE_IQ3_S ? VDR_IQ3_S_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_NL ? VDR_IQ4_NL_Q8_1_MMVQ : - type == GGML_TYPE_IQ4_XS ? VDR_IQ4_XS_Q8_1_MMVQ : - 1; + switch (type) { + case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; + default: return 1; + } } enum mmvq_parameter_table_id { @@ -73,9 +80,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } -static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) { +static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { - switch (ncols_y) { + switch (ncols_dst) { case 1: case 2: case 3: @@ -90,7 +97,7 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete return 1; } } else if (table_id == MMVQ_PARAMETERS_GCN) { - switch (ncols_y) { + switch (ncols_dst) { case 1: case 2: case 3: @@ -107,9 +114,9 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_paramete return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { - switch (ncols_y) { + switch (ncols_dst) { case 1: return 1; case 2: @@ -127,19 +134,21 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta return 1; } -template +template // tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); - constexpr int nwarps = calc_nwarps(ncols_y, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id); + constexpr int nwarps = calc_nwarps(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -147,13 +156,21 @@ static __global__ void mul_mat_vec_q( const int tid = warp_size*threadIdx.y + threadIdx.x; const int row0 = rows_per_cuda_block*blockIdx.x; const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_col_y = nrows_y / QK8_1; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - // partial sum for each thread - float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}}; + // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. + const int channel_dst = blockIdx.y; + const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; - const block_q8_1 * y = (const block_q8_1 *) vy; + // partial sum for each thread + float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + + const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx @@ -162,18 +179,19 @@ static __global__ void mul_mat_vec_q( const int kqs = vdr * (tid % (qi/vdr)); #pragma unroll - for (int j = 0; j < ncols_y; ++j) { + for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { - tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs); + tmp[j][i] += vec_dot_q_cuda( + vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); } } } - __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { #pragma unroll - for (int j = 0; j < ncols_y; ++j) { + for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; @@ -185,9 +203,11 @@ static __global__ void mul_mat_vec_q( return; } + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + // sum up partial sums and write back result #pragma unroll - for (int j = 0; j < ncols_y; ++j) { + for (int j = 0; j < ncols_dst; ++j) { #pragma unroll for (int i = 0; i < rows_per_cuda_block; ++i) { #pragma unroll @@ -197,88 +217,121 @@ static __global__ void mul_mat_vec_q( tmp[j][i] = warp_reduce_sum(tmp[j][i]); } - if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) { - dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) { + dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x]; } } - - GGML_UNUSED(nrows_x); } -static std::pair calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id); - const dim3 block_nums(nblocks, 1, 1); - const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1); +static std::pair calc_launch_params( + const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, + const int warp_size, const mmvq_parameter_table_id table_id) { + const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const dim3 block_nums(nblocks, nchannels_y, nsamples_y); + const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); return {block_nums, block_dims}; } template -static void mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { +static void mul_mat_vec_q_switch_ncols_dst( + const void * vx, const void * vy, const int32_t * ids, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); - GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); + GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); + + const int channel_ratio = nchannels_dst / nchannels_x; + const int sample_ratio = nsamples_dst / nsamples_x; const int device = ggml_cuda_get_device(); const int warp_size = ggml_cuda_info().devices[device].warp_size; const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); - switch (ncols_y) { + GGML_ASSERT(!ids || ncols_dst == 1); + switch (ncols_dst) { case 1: { - constexpr int c_ncols_y = 1; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 1; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 2: { - constexpr int c_ncols_y = 2; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 2; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 3: { - constexpr int c_ncols_y = 3; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 3; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 4: { - constexpr int c_ncols_y = 4; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 4; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 5: { - constexpr int c_ncols_y = 5; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 5; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 6: { - constexpr int c_ncols_y = 6; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 6; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 7: { - constexpr int c_ncols_y = 7; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 7; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } case 8: { - constexpr int c_ncols_y = 8; - std::pair dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id); - mul_mat_vec_q<<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); + constexpr int c_ncols_dst = 8; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q<<>> + (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); break; } default: @@ -287,137 +340,224 @@ static void mul_mat_vec_q_cuda( } } -static void mul_mat_vec_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); +static void mul_mat_vec_q_switch_type( + const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + cudaStream_t stream) { + switch (type_x) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } } -static void mul_mat_vec_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { +void ggml_cuda_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} + GGML_TENSOR_BINARY_OP_LOCALS; -static void mul_mat_vec_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + cudaStream_t stream = ctx.stream(); - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); -static void mul_mat_vec_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} + GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. -static void mul_mat_vec_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} + // If src0 is a temporary compute buffer, clear any potential padding. + if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + const size_t size_data = ggml_nbytes(src0); + const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); + if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); + } + } -static void mul_mat_vec_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1); + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + } - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = ne10_padded / QK8_1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; -static void mul_mat_vec_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { + const int64_t s12 = ne11*s11; + const int64_t s13 = ne12*s12; - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; -static void mul_mat_vec_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq2_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq3_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq1_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq1_m_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq4_nl_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq4_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); -} - -static void mul_mat_vec_iq3_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); + mul_mat_vec_q_switch_type( + src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00, + ne01, ncols_dst, s01, stride_col_y, stride_col_dst, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -440,68 +580,12 @@ void ggml_cuda_op_mul_mat_vec_q( // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } + const int stride_row_x = ne00 / ggml_blck_size(src0->type); + const int stride_col_y = src1_padded_row_size / QK8_1; + + mul_mat_vec_q_switch_type( + src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); GGML_UNUSED(src1); GGML_UNUSED(dst); diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index d9e42fdd6..39dc7d33e 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -2,6 +2,9 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + void ggml_cuda_op_mul_mat_vec_q( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 1702e4ce2..a0b03a740 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -1,30 +1,40 @@ #include "quantize.cuh" #include -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { - const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void quantize_q8_1( + const float * __restrict__ x, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int ne1, const int ne2) { + const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; - if (ix0 >= kx0_padded) { + if (i0 >= ne0) { return; } - const int64_t ix1 = blockIdx.y; + const int64_t i1 = blockIdx.y; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; - const int64_t i_padded = ix1*kx0_padded + ix0; + const int64_t & i00 = i0; + const int64_t & i01 = i1; + const int64_t & i02 = i2; + const int64_t & i03 = i3; + + const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; block_q8_1 * y = (block_q8_1 *) vy; - const int64_t ib = i_padded / QK8_1; // block index - const int64_t iqs = i_padded % QK8_1; // quant index + const int64_t ib = i_cont / QK8_1; // block index + const int64_t iqs = i_cont % QK8_1; // quant index - const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f; + const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; float amax = fabsf(xi); float sum = xi; amax = warp_reduce_max(amax); - sum = warp_reduce_sum(sum); + sum = warp_reduce_sum(sum); - const float d = amax / 127; + const float d = amax / 127; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); y[ib].qs[iqs] = q; @@ -39,29 +49,38 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest template static __global__ void quantize_mmq_q8_1( - const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) { + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int ne1, const int ne2) { constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; - const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4; + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; - if (ix0 >= kx0_padded) { + if (i0 >= ne0) { return; } - const float4 * x4 = (const float4 *) x; + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; - const int64_t ix1 = kx1*blockIdx.z + blockIdx.y; + const int64_t i00 = i0; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + const float4 * x4 = (const float4 *) x; block_q8_1_mmq * y = (block_q8_1_mmq *) vy; - const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel - const int64_t ib = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y; // block index in channel - const int64_t iqs = ix0 % (4*QK8_1); // quant index in block + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel + const int64_t iqs = i0 % (4*QK8_1); // quant index in block // Load 4 floats per thread and calculate max. abs. value between them: - const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); float amax = fabsf(xi.x); amax = fmaxf(amax, fabsf(xi.y)); amax = fmaxf(amax, fabsf(xi.z)); @@ -77,7 +96,7 @@ static __global__ void quantize_mmq_q8_1( if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { sum = xi.x + xi.y + xi.z + xi.w; - // Exchange calculate sum across vals_per_sum/4 threads. + // Calculate sums across vals_per_sum/4 threads. #pragma unroll for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); @@ -127,40 +146,42 @@ static __global__ void quantize_mmq_q8_1( } void quantize_row_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, - const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(!ids); + GGML_ASSERT(ne0 % QK8_1 == 0); - GGML_ASSERT(kx0_padded % QK8_1 == 0); - - const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, kx1*channels, 1); + const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, kx0, kx0_padded); - - GGML_UNUSED(type_x); + quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + GGML_UNUSED(type_src0); } void quantize_mmq_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, - const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) { + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne0 % (4*QK8_1) == 0); - GGML_ASSERT(kx0_padded % (4*QK8_1) == 0); - - const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); - const dim3 num_blocks(block_num_x, kx1, channels); + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); - switch (mmq_get_q8_1_ds_layout(type_x)) { + switch (mmq_get_q8_1_ds_layout(type_src0)) { case MMQ_Q8_1_DS_LAYOUT_D4: quantize_mmq_q8_1 - <<>>(x, vy, kx0, kx1, kx0_padded); + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); break; case MMQ_Q8_1_DS_LAYOUT_DS4: quantize_mmq_q8_1 - <<>>(x, vy, kx0, kx1, kx0_padded); + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); break; case MMQ_Q8_1_DS_LAYOUT_D2S6: quantize_mmq_q8_1 - <<>>(x, vy, kx0, kx1, kx0_padded); + <<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 03bf322b9..725ab5244 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); typedef void (*quantize_cuda_t)( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, - const ggml_type type_x, cudaStream_t stream); + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); void quantize_row_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, - const ggml_type type_x, cudaStream_t stream); + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); void quantize_mmq_q8_1_cuda( - const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded, - const ggml_type type_x, cudaStream_t stream); + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 37ee208c0..2d34b8360 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2) float * __restrict__ dst, const int64_t L) { GGML_UNUSED(src1_nb0); GGML_UNUSED(src2_nb0); + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); const int bidx = blockIdx.x; // split along B const int bidy = blockIdx.y; // split along D const int tid = threadIdx.x; @@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2) if (N == 16) { #pragma unroll for (size_t i = 0; i < splitD / 4; i += 2) { - float value = A_block[(wid * warpSize + i) * stride_A + wtid]; + float value = A_block[(wid * warp_size + i) * stride_A + wtid]; // todo: bank conflict // I am always confused with how to use the swizzling method to solve // bank conflit. Hoping somebody can tell me. - smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; + smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; } #pragma unroll for (size_t i = 0; i < splitD / 4; i += 2) { - float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid]; - smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; + float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid]; + smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value; } } diff --git a/ggml/src/ggml-cuda/sum.cu b/ggml/src/ggml-cuda/sum.cu index f9589080a..eb3d7cdba 100644 --- a/ggml/src/ggml-cuda/sum.cu +++ b/ggml/src/ggml-cuda/sum.cu @@ -31,7 +31,7 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); const float * src0_d = (const float *) src0->data; float * dst_d = (float *) dst->data; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu new file mode 100644 index 000000000..fb26abeb0 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 80108615a..dc1682902 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 1, 8); -DECL_FATTN_MMA_F16_CASE(80, 1, 8); -DECL_FATTN_MMA_F16_CASE(96, 1, 8); -DECL_FATTN_MMA_F16_CASE(112, 1, 8); -DECL_FATTN_MMA_F16_CASE(128, 1, 8); -DECL_FATTN_MMA_F16_CASE(256, 1, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu index 66161c0ab..9d3cfd8ed 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 1); -DECL_FATTN_MMA_F16_CASE(80, 16, 1); -DECL_FATTN_MMA_F16_CASE(96, 16, 1); -DECL_FATTN_MMA_F16_CASE(112, 16, 1); -DECL_FATTN_MMA_F16_CASE(128, 16, 1); -DECL_FATTN_MMA_F16_CASE(256, 16, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu index ee88c72aa..2e1883af4 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 2); -DECL_FATTN_MMA_F16_CASE(80, 16, 2); -DECL_FATTN_MMA_F16_CASE(96, 16, 2); -DECL_FATTN_MMA_F16_CASE(112, 16, 2); -DECL_FATTN_MMA_F16_CASE(128, 16, 2); -DECL_FATTN_MMA_F16_CASE(256, 16, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index d888a5a42..2074e954a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 16, 4); -DECL_FATTN_MMA_F16_CASE(80, 16, 4); -DECL_FATTN_MMA_F16_CASE(96, 16, 4); -DECL_FATTN_MMA_F16_CASE(112, 16, 4); -DECL_FATTN_MMA_F16_CASE(128, 16, 4); -DECL_FATTN_MMA_F16_CASE(256, 16, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu new file mode 100644 index 000000000..f011a208c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index d93a2d08e..24c64cf00 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 4); -DECL_FATTN_MMA_F16_CASE(80, 2, 4); -DECL_FATTN_MMA_F16_CASE(96, 2, 4); -DECL_FATTN_MMA_F16_CASE(112, 2, 4); -DECL_FATTN_MMA_F16_CASE(128, 2, 4); -DECL_FATTN_MMA_F16_CASE(256, 2, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 617464c94..163b1d939 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 2, 8); -DECL_FATTN_MMA_F16_CASE(80, 2, 8); -DECL_FATTN_MMA_F16_CASE(96, 2, 8); -DECL_FATTN_MMA_F16_CASE(112, 2, 8); -DECL_FATTN_MMA_F16_CASE(128, 2, 8); -DECL_FATTN_MMA_F16_CASE(256, 2, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu index 970d2b686..0543532ea 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 1); -DECL_FATTN_MMA_F16_CASE(80, 32, 1); -DECL_FATTN_MMA_F16_CASE(96, 32, 1); -DECL_FATTN_MMA_F16_CASE(112, 32, 1); -DECL_FATTN_MMA_F16_CASE(128, 32, 1); -DECL_FATTN_MMA_F16_CASE(256, 32, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu index 65cd377c3..407b6cf4c 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 32, 2); -DECL_FATTN_MMA_F16_CASE(80, 32, 2); -DECL_FATTN_MMA_F16_CASE(96, 32, 2); -DECL_FATTN_MMA_F16_CASE(112, 32, 2); -DECL_FATTN_MMA_F16_CASE(128, 32, 2); -DECL_FATTN_MMA_F16_CASE(256, 32, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu new file mode 100644 index 000000000..f5fd0e236 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu index f4a8bf348..5e4668502 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 2); -DECL_FATTN_MMA_F16_CASE(80, 4, 2); -DECL_FATTN_MMA_F16_CASE(96, 4, 2); -DECL_FATTN_MMA_F16_CASE(112, 4, 2); -DECL_FATTN_MMA_F16_CASE(128, 4, 2); -DECL_FATTN_MMA_F16_CASE(256, 4, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index de191a8ab..1ada657f1 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 4); -DECL_FATTN_MMA_F16_CASE(80, 4, 4); -DECL_FATTN_MMA_F16_CASE(96, 4, 4); -DECL_FATTN_MMA_F16_CASE(112, 4, 4); -DECL_FATTN_MMA_F16_CASE(128, 4, 4); -DECL_FATTN_MMA_F16_CASE(256, 4, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index e8cb0e1b3..bad296b41 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 4, 8); -DECL_FATTN_MMA_F16_CASE(80, 4, 8); -DECL_FATTN_MMA_F16_CASE(96, 4, 8); -DECL_FATTN_MMA_F16_CASE(112, 4, 8); -DECL_FATTN_MMA_F16_CASE(128, 4, 8); -DECL_FATTN_MMA_F16_CASE(256, 4, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu index a532e9629..0d7a9c728 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 64, 1); -DECL_FATTN_MMA_F16_CASE(80, 64, 1); -DECL_FATTN_MMA_F16_CASE(96, 64, 1); -DECL_FATTN_MMA_F16_CASE(112, 64, 1); -DECL_FATTN_MMA_F16_CASE(128, 64, 1); -DECL_FATTN_MMA_F16_CASE(256, 64, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu index bf25181aa..9d5a9976f 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 1); -DECL_FATTN_MMA_F16_CASE(80, 8, 1); -DECL_FATTN_MMA_F16_CASE(96, 8, 1); -DECL_FATTN_MMA_F16_CASE(112, 8, 1); -DECL_FATTN_MMA_F16_CASE(128, 8, 1); -DECL_FATTN_MMA_F16_CASE(256, 8, 1); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu index 378c132e6..a6e6f093d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 2); -DECL_FATTN_MMA_F16_CASE(80, 8, 2); -DECL_FATTN_MMA_F16_CASE(96, 8, 2); -DECL_FATTN_MMA_F16_CASE(112, 8, 2); -DECL_FATTN_MMA_F16_CASE(128, 8, 2); -DECL_FATTN_MMA_F16_CASE(256, 8, 2); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 372641be9..86d4ffae2 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 4); -DECL_FATTN_MMA_F16_CASE(80, 8, 4); -DECL_FATTN_MMA_F16_CASE(96, 8, 4); -DECL_FATTN_MMA_F16_CASE(112, 8, 4); -DECL_FATTN_MMA_F16_CASE(128, 8, 4); -DECL_FATTN_MMA_F16_CASE(256, 8, 4); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 9ff5968b6..680a13ca6 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -2,9 +2,9 @@ #include "../fattn-mma-f16.cuh" -DECL_FATTN_MMA_F16_CASE(64, 8, 8); -DECL_FATTN_MMA_F16_CASE(80, 8, 8); -DECL_FATTN_MMA_F16_CASE(96, 8, 8); -DECL_FATTN_MMA_F16_CASE(112, 8, 8); -DECL_FATTN_MMA_F16_CASE(128, 8, 8); -DECL_FATTN_MMA_F16_CASE(256, 8, 8); +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index dd373a09d..3428113dc 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f """ -SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n" +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", @@ -57,18 +57,21 @@ for vkq_size in [16, 32]: with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) -for ncols in [8, 16, 32, 64, 128]: - for ncols2 in [1, 2, 4, 8]: +for ncols in [8, 16, 32, 64]: + for ncols2 in [1, 2, 4, 8, 16]: + if ncols2 > ncols: + continue ncols1 = ncols // ncols2 - if ncols == 128: - continue # Too much register pressure. with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: f.write(SOURCE_FATTN_MMA_START) - for head_size in [64, 80, 96, 112, 128, 256]: - if ncols == 128 and head_size == 256: - continue # Needs too much shared memory. - f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size)) + for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: + if head_size_kq != 576 and ncols2 == 16: + continue + if head_size_kq == 576 and ncols2 != 16: + continue + head_size_v = head_size_kq if head_size_kq != 576 else 512 + f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index ec5773e01..2c0375fbe 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) { return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +static __device__ __forceinline__ float op_gelu_erf(float x) { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); +} + static __device__ __forceinline__ float op_gelu_quick(float x) { const float GELU_QUICK_COEF = -1.702f; @@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 940a1feed..6686fc17e 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40091a0ef..ba195e1d1 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -1,3 +1,5 @@ +#pragma once + #include "common.cuh" #include diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 420b41b8d..1a28831b7 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -71,6 +71,8 @@ #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMallocManaged hipMallocManaged +#define cudaMemAdvise hipMemAdvise #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index e3762649f..e29df9856 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -89,10 +89,6 @@ endif() add_compile_definitions(GGML_USE_HIP) -if (GGML_HIP_UMA) - add_compile_definitions(GGML_HIP_UMA) -endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -117,6 +113,10 @@ if (GGML_HIP_ROCWMMA_FATTN) add_compile_definitions(GGML_HIP_ROCWMMA_FATTN) endif() +if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0) + add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) +endif() + if (NOT GGML_CUDA_FA) add_compile_definitions(GGML_CUDA_NO_FA) endif() diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index a19cfb14e..6dc5ce0d9 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -32,6 +32,8 @@ extern "C" { #endif +void ggml_print_backtrace(void); + #ifndef MIN # define MIN(a, b) ((a) < (b) ? (a) : (b)) #endif @@ -386,7 +388,7 @@ GGML_API void ggml_aligned_free(void * ptr, size_t size); return r; } -#elif defined(__riscv) && defined(GGML_RV_ZFH) +#elif defined(__riscv) && defined(__riscv_zfhmin) static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) { float f; diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index e22232780..77187efc1 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -44,21 +44,22 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( - OUTPUT ${METALLIB_EMBED_ASM} + OUTPUT "${METALLIB_EMBED_ASM}" COMMAND echo "Embedding Metal library" - COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} - COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} - COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} - COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} - COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} - COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} - COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} - COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} + COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}" + COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}" + COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}" + COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}" + COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}" + COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}" + COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}" + COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}" DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h COMMENT "Generate assembly for embedded Metal library" + VERBATIM ) - target_sources(ggml-metal PRIVATE ${METALLIB_EMBED_ASM}) + target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}") else() if (GGML_METAL_SHADER_DEBUG) # custom command to do the following: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8721b272d..17eab976f 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -207,6 +207,10 @@ typedef struct { float attn_factor; float beta_fast; float beta_slow; + int32_t sect_0; + int32_t sect_1; + int32_t sect_2; + int32_t sect_3; } ggml_metal_kargs_rope; typedef struct { @@ -299,21 +303,42 @@ typedef struct { } ggml_metal_kargs_mul_mv_ext; typedef struct { - int32_t nei0; - int32_t nei1; - uint64_t nbi1; + int32_t ne10; + int32_t ne11; // n_expert_used (bcast) + uint64_t nb11; + uint64_t nb12; + int32_t neh11; // n_tokens + uint64_t nbh11; + int32_t ne20; // n_expert_used + uint64_t nb21; +} ggml_metal_kargs_mul_mm_id_map0; + +typedef struct { + int32_t ne20; // n_expert_used + int32_t neh0; + int32_t neh1; + uint64_t nbh1; + uint64_t nbh2; + int32_t ne0; + uint64_t nb1; + uint64_t nb2; +} ggml_metal_kargs_mul_mm_id_map1; + +typedef struct { int32_t ne00; int32_t ne02; uint64_t nb01; uint64_t nb02; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - int32_t ne0; - int32_t ne1; + uint64_t nb03; + int32_t neh12; + uint64_t nbh10; + uint64_t nbh11; + uint64_t nbh12; + uint64_t nbh13; + int32_t neh0; + int32_t neh1; + int16_t r2; + int16_t r3; } ggml_metal_kargs_mul_mm_id; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 9f1c6c6cc..bc93bc633 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device; // note: assumes single GPU device - the default one // TODO: support multiple GPU devices static struct ggml_backend_metal_device_context { - id mtl_device; - int mtl_device_ref_count; + id mtl_device; + int mtl_device_ref_count; id mtl_library; bool has_simdgroup_reduction; @@ -149,6 +149,8 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SIGMOID, GGML_METAL_KERNEL_TYPE_GELU, GGML_METAL_KERNEL_TYPE_GELU_4, + GGML_METAL_KERNEL_TYPE_GELU_ERF, + GGML_METAL_KERNEL_TYPE_GELU_ERF_4, GGML_METAL_KERNEL_TYPE_GELU_QUICK, GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, GGML_METAL_KERNEL_TYPE_SILU, @@ -306,30 +308,36 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, - GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, @@ -354,6 +362,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, @@ -362,6 +371,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, @@ -370,6 +380,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, @@ -378,6 +389,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, @@ -386,6 +398,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, @@ -394,6 +407,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, @@ -402,6 +416,21 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, @@ -430,6 +459,13 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_SET_I32, GGML_METAL_KERNEL_TYPE_SET_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, @@ -460,6 +496,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SQRT, GGML_METAL_KERNEL_TYPE_SIN, GGML_METAL_KERNEL_TYPE_COS, + GGML_METAL_KERNEL_TYPE_NEG, GGML_METAL_KERNEL_TYPE_SUM_ROWS, GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, @@ -468,7 +505,264 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_COUNT }; +// +// ggml_metal_heap +// + +struct ggml_metal_heap { + // number of times the heap was unused + int n_unused; + + // total number of buffer allocations in this heap across all computes + int64_t n_alloc; + + // current offset in the heap - we reset this after each node in order to reuse the memory + size_t offs; + + // the currently allocated MTLBuffer objects in this heap + id obj; + + NSMutableArray * bufs; +}; + +static struct ggml_metal_heap * ggml_metal_heap_init(id device, size_t size) { + struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap)); + + MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init]; + desc.storageMode = MTLStorageModePrivate; + desc.cpuCacheMode = MTLCPUCacheModeDefaultCache; + desc.type = MTLHeapTypePlacement; + desc.size = size; + + heap->n_unused = 0; + heap->n_alloc = 0; + + heap->obj = [device newHeapWithDescriptor:desc]; + if (!heap->obj) { + GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size); + + free(heap); + + return false; + } + + [desc release]; + + heap->bufs = [[NSMutableArray alloc] init]; + + return heap; +} + +static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) { + heap->offs = 0; + + // count how many graph computes the heap ended up being unused + if ([heap->bufs count] > 0) { + heap->n_unused = 0; + } else { + heap->n_unused++; + } + + for (id buf in heap->bufs) { + [buf release]; + } + [heap->bufs removeAllObjects]; + + // tell the OS that it can reuse this memory if needed + // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc + [heap->obj setPurgeableState:MTLPurgeableStateVolatile]; +} + +static void ggml_metal_heap_free(struct ggml_metal_heap * heap) { + if (heap == nil) { + return; + } + + ggml_metal_heap_reset(heap); + + [heap->obj release]; + [heap->bufs release]; + + free(heap); +} + +@interface ggml_metal_heap_ptr : NSObject + +@property (nonatomic, assign) struct ggml_metal_heap * data; + +@end + +@implementation ggml_metal_heap_ptr +@end + +// +// ggml_metal_mem_pool +// + +struct ggml_metal_mem_pool { + id device; + + int n_heaps; // total number of heaps ever created (including those that were removed) + + NSMutableArray * heaps; + NSMutableArray * heaps_to_remove; +}; + +static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) { + struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool)); + + mem_pool->n_heaps = 0; + + mem_pool->heaps = [[NSMutableArray alloc] init]; + mem_pool->heaps_to_remove = [[NSMutableArray alloc] init]; + + return mem_pool; +} + +static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) { + GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps); + + size_t size_all = 0; + size_t size_cur = 0; + + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data); + GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc); + GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused); + GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0); + GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]); + + if ([ptr.data->bufs count] > 0) { + size_cur += [ptr.data->obj size]; + } + size_all += [ptr.data->obj size]; + + ggml_metal_heap_free(ptr.data); + [ptr release]; + } + [mem_pool->heaps release]; + [mem_pool->heaps_to_remove release]; + + if (size_all > 0) { + GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0); + GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0); + } + + free(mem_pool); +} + +static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) { + for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) { + ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i]; + + struct ggml_metal_heap * heap = ptr.data; + ggml_metal_heap_reset(heap); + + // if the heap hasn't been used for a while, remove it + if (heap->n_unused >= 128) { + [mem_pool->heaps_to_remove addObject:@(i)]; + } + } + + if (mem_pool->heaps_to_remove.count > 0) { + // remove in reverse order + for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) { + NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue]; + ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index]; + + struct ggml_metal_heap * heap = ptr.data; + ggml_metal_heap_free(heap); + + [mem_pool->heaps removeObjectAtIndex:index]; + [ptr release]; + + if (i == 0) { + break; + } + } + + [mem_pool->heaps_to_remove removeAllObjects]; + } +} + +static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) { + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + ptr.data->offs = 0; + } +} + +static id ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) { + const size_t alignment = 256; + + const size_t size_aligned = GGML_PAD(size, alignment); + + // try one of the existing heaps + for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) { + struct ggml_metal_heap * heap = ptr.data; + if (heap->offs + size_aligned <= [heap->obj size]) { + // if this is the first buffer in the heap for the current command buffer, tell the OS that + // it cannot free the memory used by the heap + // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc + if ([heap->bufs count] == 0) { + [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; + } + + id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; + if (buf == nil) { + GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); + return nil; + } + + heap->n_alloc++; + heap->offs += size_aligned; + + [heap->bufs addObject:buf]; + + return buf; + } + } + + // create a new heap that can fit this buffer + ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new]; + + struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned); + if (heap == NULL) { + GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned); + return NULL; + } + + //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]); + + heap_ptr.data = heap; + ggml_metal_heap_reset(heap); + + [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile]; + id buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs]; + if (buf == nil) { + GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned); + return NULL; + } + + heap->n_alloc++; + heap->offs += size_aligned; + + [heap->bufs addObject:buf]; + + [mem_pool->heaps addObject:heap_ptr]; + mem_pool->n_heaps++; + + return buf; +} + +struct ggml_metal_command_buffer { + id obj; + + // each command buffer has a memory pool from which it can allocate temporary buffers during the compute + struct ggml_metal_mem_pool * mem_pool; +}; + struct ggml_backend_metal_context { + id device; id queue; dispatch_queue_t d_queue; @@ -493,7 +787,7 @@ struct ggml_backend_metal_context { void (^encode_async)(size_t ith); // n_cb command buffers + 1 used by the main thread - id command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; + struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; // abort ggml_metal_graph_compute if callback returns true ggml_abort_callback abort_callback; @@ -683,9 +977,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de struct ggml_backend_metal_device_context * ctx_dev = dev->context; id device = ggml_backend_metal_device_acq(ctx_dev); + GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); - ctx->queue = [device newCommandQueue]; + ctx->device = device; + ctx->queue = [device newCommandQueue]; if (ctx->queue == nil) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); return NULL; @@ -746,7 +1042,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de ctx->gf = nil; ctx->encode_async = nil; for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { - ctx->command_buffers[i] = nil; + ctx->cmd_bufs[i].obj = nil; + + ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init(); + ctx->cmd_bufs[i].mem_pool->device = device; } #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) @@ -806,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); @@ -963,30 +1264,36 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, mul_mm_id_bf16_f32, has_simdgroup_mm && use_bfloat); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, has_simdgroup_mm); - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16, mul_mm_id_q4_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16, mul_mm_id_q4_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16, mul_mm_id_q5_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16, mul_mm_id_q6_K_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16, mul_mm_id_iq2_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16, mul_mm_id_iq3_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16, mul_mm_id_iq2_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16, mul_mm_id_iq1_s_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); @@ -1011,6 +1318,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192, flash_attn_ext_f16_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128, flash_attn_ext_f16_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512, flash_attn_ext_f16_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && use_bfloat); @@ -1019,6 +1327,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192, flash_attn_ext_bf16_h192, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm); @@ -1027,6 +1336,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192, flash_attn_ext_q4_0_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256, flash_attn_ext_q4_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64, flash_attn_ext_q4_1_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80, flash_attn_ext_q4_1_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96, flash_attn_ext_q4_1_h96, has_simdgroup_mm); @@ -1035,6 +1345,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192, flash_attn_ext_q4_1_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256, flash_attn_ext_q4_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64, flash_attn_ext_q5_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80, flash_attn_ext_q5_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96, flash_attn_ext_q5_0_h96, has_simdgroup_mm); @@ -1043,6 +1354,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192, flash_attn_ext_q5_0_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256, flash_attn_ext_q5_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64, flash_attn_ext_q5_1_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80, flash_attn_ext_q5_1_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96, flash_attn_ext_q5_1_h96, has_simdgroup_mm); @@ -1051,6 +1363,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192, flash_attn_ext_q5_1_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256, flash_attn_ext_q5_1_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64, flash_attn_ext_q8_0_h64, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80, flash_attn_ext_q8_0_h80, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96, flash_attn_ext_q8_0_h96, has_simdgroup_mm); @@ -1059,6 +1372,21 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192, flash_attn_ext_q8_0_h192, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64, flash_attn_ext_vec_q4_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64, flash_attn_ext_vec_q5_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64, flash_attn_ext_vec_q5_1_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64, flash_attn_ext_vec_q8_0_h64, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96, flash_attn_ext_vec_f16_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96, flash_attn_ext_vec_bf16_h96, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96, flash_attn_ext_vec_q4_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96, flash_attn_ext_vec_q4_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96, flash_attn_ext_vec_q5_0_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96, flash_attn_ext_vec_q5_1_h96, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96, flash_attn_ext_vec_q8_0_h96, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction); @@ -1087,6 +1415,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512, flash_attn_ext_vec_f16_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512, flash_attn_ext_vec_bf16_hk576_hv512, has_simdgroup_reduction && use_bfloat); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512, flash_attn_ext_vec_q4_0_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512, flash_attn_ext_vec_q4_1_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); @@ -1117,6 +1452,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); @@ -1137,6 +1473,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) { [ctx->queue release]; + for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { + // ctx->cmd_bufs[i].obj is auto released + + ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool); + } + dispatch_release(ctx->d_queue); free(ctx); @@ -1275,9 +1617,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_NEG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -1320,16 +1664,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_NORM: return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return true; - } + return true; case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: @@ -1351,6 +1686,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex // TODO: not sure if it is worth adding kernels for this size return false; } + if (op->src[0]->ne[0] == 576) { + // DeepSeek sizes + // TODO: disabled for now, until optmized + return false; + } if (op->src[1]->type != op->src[2]->type) { return false; } @@ -1436,10 +1776,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex } } -static void ggml_metal_encode_node( +static bool ggml_metal_encode_node( ggml_backend_t backend, int idx, - id encoder) { + id encoder, + struct ggml_metal_mem_pool * mem_pool) { struct ggml_backend_metal_context * ctx = backend->context; struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; @@ -1455,7 +1796,7 @@ static void ggml_metal_encode_node( struct ggml_tensor * dst = node; if (ggml_is_empty(dst)) { - return; + return true; } switch (dst->op) { @@ -1466,7 +1807,7 @@ static void ggml_metal_encode_node( case GGML_OP_PERMUTE: { // noop -> next node - } return; + } return true; default: { } break; @@ -1477,6 +1818,8 @@ static void ggml_metal_encode_node( GGML_ABORT("unsupported op"); } + ggml_metal_mem_pool_clear(mem_pool); + const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; const int64_t ne02 = src0 ? src0->ne[2] : 0; @@ -1913,6 +2256,25 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_GELU_ERF: + { + int64_t n = ggml_nelements(dst); + + id pipeline = nil; + + if (n % 4 == 0) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline; + n /= 4; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline; + } + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_UNARY_OP_GELU_QUICK: { int64_t n = ggml_nelements(dst); @@ -1963,6 +2325,18 @@ static void ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_NEG: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); @@ -2111,26 +2485,76 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_metal_kargs_soft_max args = { +// use this branch to test the ggml_metal_mem_pool functionality +#if 0 + // cpy to tmp buffer in MTLHeap + + id h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0)); + if (!h_src0) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0)); + return false; + } + + offs_src0 = 0; + + ggml_metal_kargs_cpy args_cpy = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, - /*.scale =*/ scale, - /*.max_bias =*/ max_bias, - /*.m0 =*/ m0, - /*.m1 =*/ m1, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne00, + /*.ne1 =*/ ne01, + /*.ne2 =*/ ne02, + /*.ne3 =*/ ne03, + /*.nb0 =*/ nb00, + /*.nb1 =*/ nb01, + /*.nb2 =*/ nb02, + /*.nb3 =*/ nb03, + }; + + if (src0->type == GGML_TYPE_F16) { + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline]; + } else { + [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline]; + } + [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:h_src0 offset:0 atIndex:2]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type)); + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)]; + +#else + id h_src0 = id_src0; +#endif + // softmax + + ggml_metal_kargs_soft_max args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, /*.n_head_log2 =*/ n_head_log2, }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0]; if (id_src1) { - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; @@ -2621,7 +3045,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { id pipeline = nil; @@ -2841,8 +3265,6 @@ static void ggml_metal_encode_node( } break; case GGML_OP_MUL_MAT_ID: { - const int n_as = src0->ne[2]; - // src2 = ids const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t); @@ -2856,24 +3278,21 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne03 == 1); GGML_ASSERT(ne13 == 1); + const uint32_t r2 = 1; + const uint32_t r3 = 1; + // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel // ne20 = n_used_experts - // ne21 = n_rows - const int dst_rows = ne20*ne21; - const int dst_rows_min = n_as; - const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4; - - // max size of the rowids array in the kernel shared buffer - //GGML_ASSERT(dst_rows <= dst_rows_max); + // ne21 = n_rows (batch size) + const int ne21_mm_id_min = 32; // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([device supportsFamily:MTLGPUFamilyApple7] && ne00 % 32 == 0 && ne00 >= 64 && - //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments - dst_rows > dst_rows_min && - dst_rows <= dst_rows_max) { + (ne21 >= ne21_mm_id_min)) { + GGML_ASSERT(ne00 % 4 == 0); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) @@ -2884,62 +3303,169 @@ static void ggml_metal_encode_node( default: break; } - id pipeline = nil; + const int64_t neh10 = ne10; // n_embd + const int64_t neh11 = ne21; // n_tokens + const int64_t neh12 = ne02; // n_expert - switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break; - case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32 ].pipeline; break; - case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break; - case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32 ].pipeline; break; - case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32 ].pipeline; break; - case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32 ].pipeline; break; - case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32 ].pipeline; break; - case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32 ].pipeline; break; - case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32 ].pipeline; break; - case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32 ].pipeline; break; - case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32 ].pipeline; break; - case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32 ].pipeline; break; - case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break; - case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break; - case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break; - case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline; break; - case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break; - case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32 ].pipeline; break; - case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break; - case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break; - default: GGML_ABORT("MUL_MAT_ID not implemented"); + const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16); + const uint64_t nbh11 = nbh10*neh10; + const uint64_t nbh12 = nbh11*neh11; + const uint64_t nbh13 = nbh12*neh12; + + const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12; + id h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1); + if (!h_src1) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1); + return false; } - ggml_metal_kargs_mul_mm_id args = { - /*.nei0 =*/ ne20, - /*.nei1 =*/ ne21, - /*.nbi1 =*/ nb21, - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.ne13 =*/ ne13, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - }; + const int64_t neh0 = ne0; + const int64_t neh1 = ne21; + const int64_t neh2 = ne02; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; + const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32); + const uint64_t nbh1 = nbh0*neh0; + const uint64_t nbh2 = nbh1*neh1; + //const uint64_t nbh3 = nbh2*neh2; - [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; + const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2; + id h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst); + if (!h_dst) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst); + return false; + } - [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + // tokens per expert + const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02; + id h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe); + if (!h_tpe) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe); + return false; + } + + // id map + // [n_expert_used, n_tokens] + const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21; + id h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids); + if (!h_ids) { + GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids); + return false; + } + + { + const int nth = MIN(1024, ne10/4); + + ggml_metal_kargs_mul_mm_id_map0 args = { + ne10, + ne11, // n_expert_used (bcast) + nb11, + nb12, + neh11, // n_tokens + nbh11, + ne20, // n_expert_used + nb21, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer: h_src1 offset:0 atIndex:3]; + [encoder setBuffer: h_tpe offset:0 atIndex:4]; + [encoder setBuffer: h_ids offset:0 atIndex:5]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } + + { + id pipeline = nil; + + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16 ].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16 ].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16 ].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16 ].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16 ].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break; + case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break; + case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break; + case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break; + case GGML_TYPE_Q5_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16 ].pipeline; break; + case GGML_TYPE_Q6_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16 ].pipeline; break; + case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break; + case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break; + case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break; + case GGML_TYPE_IQ3_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16 ].pipeline; break; + case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16 ].pipeline; break; + case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; + case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; + case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; + default: GGML_ABORT("MUL_MAT_ID not implemented"); + } + + ggml_metal_kargs_mul_mm_id args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.neh12 =*/ neh12, + /*.nbh10 =*/ nbh10, + /*.nbh11 =*/ nbh11, + /*.nbh12 =*/ nbh12, + /*.nbh13 =*/ nbh13, + /*.neh0 =*/ neh0, + /*.neh1 =*/ neh1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer: h_src1 offset:0 atIndex:2]; + [encoder setBuffer: h_tpe offset:0 atIndex:3]; + [encoder setBuffer: h_dst offset:0 atIndex:4]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } + + { + GGML_ASSERT(ne0 % 4 == 0); + + const int nth = MIN(1024, ne0/4); + + ggml_metal_kargs_mul_mm_id_map1 args = { + ne20, // n_expert_used + neh0, + neh1, + nbh1, + nbh2, + ne0, + nb1, + nb2, + }; + + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer: h_dst offset:0 atIndex:1]; + [encoder setBuffer: h_ids offset:0 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } else { id pipeline = nil; @@ -3133,7 +3659,7 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; - const int64_t ne123 = dst_rows; + const int64_t ne123 = ne20*ne21; if (smem > 0) { [encoder setThreadgroupMemoryLength:smem atIndex:0]; @@ -3337,6 +3863,7 @@ static void ggml_metal_encode_node( } break; case GGML_OP_ROPE: { + // make sure we have one or more position id(ne10) per token(ne02) GGML_ASSERT(ne10 % ne02 == 0); GGML_ASSERT(ne10 >= ne02); @@ -3363,20 +3890,42 @@ static void ggml_metal_encode_node( memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); - const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + // mrope + const int sect_0 = ((const int32_t *) dst->op_params)[11]; + const int sect_1 = ((const int32_t *) dst->op_params)[12]; + const int sect_2 = ((const int32_t *) dst->op_params)[13]; + const int sect_3 = ((const int32_t *) dst->op_params)[14]; id pipeline = nil; - if (!is_neox) { + if (is_neox) { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_mrope && !is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_vision) { + GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } else { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } @@ -3407,6 +3956,10 @@ static void ggml_metal_encode_node( /*.attn_factor =*/ attn_factor, /*.beta_fast =*/ beta_fast, /*.beta_slow =*/ beta_slow, + /* sect_0 =*/ sect_0, + /* sect_1 =*/ sect_1, + /* sect_2 =*/ sect_2, + /* sect_3 =*/ sect_3, }; [encoder setComputePipelineState:pipeline]; @@ -3843,12 +4396,14 @@ static void ggml_metal_encode_node( // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // for now avoiding mainly to keep the number of templates/kernels a bit lower // these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612 - if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 192)) { + if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) { switch (src1->type) { case GGML_TYPE_F16: { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; @@ -3871,6 +4426,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break; @@ -3893,6 +4450,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break; @@ -3915,6 +4474,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break; @@ -3937,6 +4498,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break; @@ -3959,6 +4522,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break; @@ -3981,6 +4546,8 @@ static void ggml_metal_encode_node( { if (ne00 == 192 && ne20 == 128) { pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline; + } else if (ne00 == 576 && ne20 == 512) { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline; } else { switch (ne00) { case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break; @@ -4010,6 +4577,42 @@ static void ggml_metal_encode_node( use_vec_kernel = true; switch (ne00) { + case 64: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H64].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H64].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H64].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H64].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; + case 96: + { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H96].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H96].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H96].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H96].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } break; case 128: { switch (src1->type) { @@ -4082,12 +4685,36 @@ static void ggml_metal_encode_node( } } } break; + case 576: + { + if (ne20 == 512) { + switch (src1->type) { + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break; + case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break; + case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break; + case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break; + case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break; + case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break; + case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break; + default: + { + GGML_LOG_ERROR("unsupported type: %d\n", src1->type); + GGML_LOG_ERROR("add template specialization for this type\n"); + GGML_ABORT("add template specialization for this type"); + } + } + } else { + GGML_LOG_ERROR("unsupported size: %lld\n", ne20); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } + } break; default: - { - GGML_LOG_ERROR("unsupported size: %lld\n", ne00); - GGML_LOG_ERROR("add template specialization for this size\n"); - GGML_ABORT("add template specialization for this size"); - } + { + GGML_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_LOG_ERROR("add template specialization for this size\n"); + GGML_ABORT("add template specialization for this size"); + } } } @@ -4139,6 +4766,8 @@ static void ggml_metal_encode_node( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + const int is_q = ggml_is_quantized(src1->type) ? 1 : 0; + // 2*(2*ncpsg + nqptg)*(nsg) // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float) // @@ -4146,7 +4775,7 @@ static void ggml_metal_encode_node( // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; @@ -4183,9 +4812,9 @@ static void ggml_metal_encode_node( // and store the soft_max values and the mask // // ne00*(nsg) - // each simdgroup has a full f16 head vector in shared mem to accumulate results + // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; while (true) { @@ -4483,6 +5112,8 @@ static void ggml_metal_encode_node( GGML_ABORT("fatal error"); } } + + return true; } static enum ggml_status ggml_metal_graph_compute( @@ -4536,25 +5167,25 @@ static enum ggml_status ggml_metal_graph_compute( } // the main thread commits the first few commands immediately - // command_buffer[n_cb] + // cmd_buf[n_cb] { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->command_buffers[n_cb] = command_buffer; + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->cmd_bufs[n_cb].obj = cmd_buf; - [command_buffer enqueue]; + [cmd_buf enqueue]; ctx->encode_async(n_cb); } // prepare the rest of the command buffers asynchronously - // command_buffer[0.. n_cb) + // cmd_buf[0.. n_cb) for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { - id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; - ctx->command_buffers[cb_idx] = command_buffer; + id cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; + ctx->cmd_bufs[cb_idx].obj = cmd_buf; // always enqueue the first two command buffers // enqueue all of the command buffers if we don't need to abort if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer enqueue]; + [cmd_buf enqueue]; } } @@ -4563,14 +5194,14 @@ static enum ggml_status ggml_metal_graph_compute( // wait for completion and check status of each command buffer // needed to detect if the device ran out-of-memory for example (#1881) { - id command_buffer = ctx->command_buffers[n_cb]; - [command_buffer waitUntilCompleted]; + id cmd_buf = ctx->cmd_bufs[n_cb].obj; + [cmd_buf waitUntilCompleted]; - MTLCommandBufferStatus status = [command_buffer status]; + MTLCommandBufferStatus status = [cmd_buf status]; if (status != MTLCommandBufferStatusCompleted) { GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } return GGML_STATUS_FAILED; @@ -4578,20 +5209,20 @@ static enum ggml_status ggml_metal_graph_compute( } for (int i = 0; i < n_cb; ++i) { - id command_buffer = ctx->command_buffers[i]; - [command_buffer waitUntilCompleted]; + id cmd_buf = ctx->cmd_bufs[i].obj; + [cmd_buf waitUntilCompleted]; - MTLCommandBufferStatus status = [command_buffer status]; + MTLCommandBufferStatus status = [cmd_buf status]; if (status != MTLCommandBufferStatusCompleted) { GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); if (status == MTLCommandBufferStatusError) { - GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]); + GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } return GGML_STATUS_FAILED; } - id next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil); + id next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); if (!next_buffer) { continue; } @@ -4974,8 +5605,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const int n_nodes_per_cb = ctx->n_nodes_per_cb; - id command_buffer = ctx->command_buffers[cb_idx]; - id encoder = [command_buffer computeCommandEncoder]; + id cmd_buf = ctx->cmd_bufs[cb_idx].obj; + + id encoder = [cmd_buf computeCommandEncoder]; int node_start = 0; int node_end = n_nodes_0; @@ -4987,22 +5619,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { const bool should_capture = ctx->capture_next_compute; + struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool; + ggml_metal_mem_pool_reset(mem_pool); + for (int idx = node_start; idx < node_end; ++idx) { if (should_capture) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - ggml_metal_encode_node(backend, idx, encoder); + const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); if (should_capture) { [encoder popDebugGroup]; } + + if (!res) { + break; + } } [encoder endEncoding]; if (cb_idx < 2 || ctx->abort_callback == NULL) { - [command_buffer commit]; + [cmd_buf commit]; } }); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b08666e27..5d7760217 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -856,6 +856,7 @@ kernel void kernel_tanh( constant float GELU_COEF_A = 0.044715f; constant float GELU_QUICK_COEF = -1.702f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +constant float SQRT_2_INV = 0.70710678118654752440084436210484f; kernel void kernel_gelu( device const float * src0, @@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4( dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); } +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +constant float p_erf = 0.3275911f; +constant float a1_erf = 0.254829592f; +constant float a2_erf = -0.284496736f; +constant float a3_erf = 1.421413741f; +constant float a4_erf = -1.453152027f; +constant float a5_erf = 1.061405429f; + +template +T erf_approx(T x) { + T sign_x = sign(x); + x = fabs(x); + T t = 1.0f / (1.0f + p_erf * x); + T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + return sign_x * y; +} + +kernel void kernel_gelu_erf( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); +} + +kernel void kernel_gelu_erf_4( + device const float4 * src0, + device float4 * dst, + uint tpig[[thread_position_in_grid]]) { + device const float4 & x = src0[tpig]; + + dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); +} + kernel void kernel_silu( device const float * src0, device float * dst, @@ -949,6 +986,13 @@ kernel void kernel_cos( dst[tpig] = cos(src0[tpig]); } +kernel void kernel_neg( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = -src0[tpig]; +} + kernel void kernel_sum_rows( device const float * src0, device float * dst, @@ -2706,8 +2750,148 @@ kernel void kernel_rope_neox( } } +template +kernel void kernel_rope_multi( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + + // mrope theta calculations + // note: the rest is the same as kernel_rope_neox + const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; + const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1 + const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 + const int sector = ic % sect_dims; + + float theta_base; + if (sector < args.sect_0) { + theta_base = (float) pos[i2]; + } else if (sector < sec_w01) { + theta_base = (float) pos[i2 + args.ne02]; + } else if (sector < sec_w012) { + theta_base = (float) pos[i2 + args.ne02 * 2]; + } else { + theta_base = (float) pos[i2 + args.ne02 * 3]; + } + // end of mrope + + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_vision( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < 2*args.n_dims) { // different from kernel_rope_multi + const int ic = i0/2; + + // mrope theta calculations (only support 2 dimensions) + const int sect_dims = args.sect_0 + args.sect_1; + const int sector = ic % sect_dims; + + float p; + float theta_base; + if (sector < args.sect_1) { + p = (float) sector; + theta_base = (float) pos[i2]; + } else { + p = (float) sector - args.sect_0; + theta_base = (float) pos[i2 + args.ne02]; + } + + const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); + // end of mrope + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims]; // different from kernel_rope_multi + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + typedef decltype(kernel_rope_norm) kernel_rope_norm_t; typedef decltype(kernel_rope_neox) kernel_rope_neox_t; +typedef decltype(kernel_rope_multi) kernel_rope_multi_t; +typedef decltype(kernel_rope_vision) kernel_rope_vision_t; template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; @@ -2715,6 +2899,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; + +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; + typedef void (im2col_t)( device const float * x, device char * dst, @@ -3102,7 +3292,7 @@ template< typename kd4x4_t, // key type in device memory short nl_k, void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), - typename vd4x4_t, // key type in device memory + typename vd4x4_t, // value type in device memory short nl_v, void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size @@ -3138,14 +3328,12 @@ kernel void kernel_flash_attn_ext( constexpr short NW = N_SIMDWIDTH; constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float) - const short TS = nsg*SH; // shared memory size per query in (s_t == float) - const short T = DK + 2*TS; // shared memory size per query in (half) + const short TS = nsg*SH; // shared memory size per query in (s_t == float) + const short T = 2*DK + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t @@ -3164,7 +3352,7 @@ kernel void kernel_flash_attn_ext( if (iq1 + j < args.ne01) { sq4[j*DK4 + i] = (q4_t) q4[i]; } else { - sq4[j*DK4 + i] = (q4_t) 0.0f; + sq4[j*DK4 + i] = 0; } } } @@ -3185,7 +3373,7 @@ kernel void kernel_flash_attn_ext( { float S[Q] = { [0 ... Q-1] = 0.0f }; - float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 }; + float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 }; // thread indices inside the simdgroup // TODO: see if we can utilize quad-group functions for better performance @@ -3358,20 +3546,20 @@ kernel void kernel_flash_attn_ext( // O = diag(ms)*O { - s8x8_t mm; - simdgroup_load(mm, ss + 2*C, TS, 0, false); + s8x8_t ms; + simdgroup_load(ms, ss + 2*C, TS, 0, false); #pragma unroll(DV8) for (short i = 0; i < DV8; ++i) { - simdgroup_multiply(lo[i], mm, lo[i]); + simdgroup_multiply(lo[i], ms, lo[i]); } } // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/8; ++cc) { - s8x8_t ms; - simdgroup_load(ms, ss + 8*cc, TS, 0, false); + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, TS, 0, false); if (is_same::value) { // we can read directly from global memory @@ -3382,7 +3570,7 @@ kernel void kernel_flash_attn_ext( v8x8_t mv; simdgroup_load(mv, pv + i*8, args.nb21/sizeof(v_t), 0, false); // TODO: use ne20 - simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); + simdgroup_multiply_accumulate(lo[i], vs, mv, lo[i]); } } else { for (short ii = 0; ii < DV16; ii += 4) { @@ -3403,10 +3591,10 @@ kernel void kernel_flash_attn_ext( v8x8_t mv; simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); } } else { if (ii + tx < DV16) { @@ -3421,10 +3609,10 @@ kernel void kernel_flash_attn_ext( v8x8_t mv; simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], vs, mv, lo[2*(ii + k) + 0]); simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false); - simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]); + simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], vs, mv, lo[2*(ii + k) + 1]); } } } @@ -3434,93 +3622,89 @@ kernel void kernel_flash_attn_ext( } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) - for (short j = 0; j < Q; ++j) { - if (tiisg == 0) { - ss[j*TS + 0] = S[j]; - ss[j*TS + 1] = M[j]; - } + for (short j = tiisg; j < Q; j += NW) { + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; } } + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup float * so = (threadgroup float *) (shmem_f16 + 0*DK); // reuse query data for accumulation + threadgroup float4 * so4 = (threadgroup float4 *) (shmem_f16 + 0*DK); + + // store result to shared memory in F32 + if (sgitg == 0) { + for (short i = 0; i < DV8; ++i) { + //simdgroup_store(lo[i], so + i*8, DV, 0, false); + simdgroup_float8x8 t(1.0f); + simdgroup_multiply(t, lo[i], t); + simdgroup_store(t, so + i*8, DV, 0, false); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // reduce the warps sequentially for (ushort sg = 1; sg < nsg; ++sg) { - float S = { 0.0f }; - float M = { -__FLT16_MAX__/2 }; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // each simdgroup stores its output to shared memory, reusing sq if (sgitg == sg) { - for (short i = 0; i < DV8; ++i) { - simdgroup_store(lo[i], so + i*8, DV, 0, false); + for (short j = tiisg; j < Q; j += NW) { + const float S0 = ss[j*TS - 1*SH + 0]; + const float S1 = ss[j*TS + 0]; + + const float M0 = ss[j*TS - 1*SH + 1]; + const float M1 = ss[j*TS + 1]; + + const float M = max(M0, M1); + + float ms0 = exp(M0 - M); + float ms1 = exp(M1 - M); + + const float S = S0*ms0 + S1*ms1; + + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; + + ss[j*TS + 2*C + j - 1*SH] = ms0; + ss[j*TS + 2*C + j ] = ms1; } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // the first simdgroup accumulates the results from the other simdgroups - if (sgitg == 0) { - for (short j = 0; j < Q; ++j) { - const float S0 = ss[j*TS + 0]; - const float S1 = ss[j*TS + sg*SH + 0]; - - const float M0 = ss[j*TS + 1]; - const float M1 = ss[j*TS + sg*SH + 1]; - - M = max(M0, M1); - - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); - - S = S0*ms0 + S1*ms1; - - if (tiisg == 0) { - ss[j*TS + 0] = S; - ss[j*TS + 1] = M; - - ss[j*TS + 2*C + j ] = ms0; - ss[j*TS + 2*C + j + sg*SH] = ms1; - } - } + //simdgroup_barrier(mem_flags::mem_threadgroup); // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 { s8x8_t ms0; s8x8_t ms1; - simdgroup_load(ms0, ss + 2*C, TS, 0, false); - simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); + simdgroup_load(ms0, ss + 2*C - 1*SH, TS, 0, false); + simdgroup_load(ms1, ss + 2*C, TS, 0, false); #pragma unroll(DV8) for (short i = 0; i < DV8; ++i) { - o8x8_t t; + simdgroup_float8x8 t; simdgroup_load (t, so + i*8, DV, 0, false); - simdgroup_multiply(t, ms1, t); + simdgroup_multiply(t, ms0, t); - simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t); + simdgroup_multiply_accumulate(t, ms1, lo[i], t); + simdgroup_store(t, so + i*8, DV, 0, false); } } } + + threadgroup_barrier(mem_flags::mem_threadgroup); } - // store result to shared memory (reuse sq) - if (sgitg == 0) { - for (short i = 0; i < DV8; ++i) { - simdgroup_store(lo[i], so + i*8, DV, 0, false); - } - } - - device float4 * dst4 = (device float4 *) dst; + threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*(nsg-1)*SH + 2*Q*DK); // final rescale with 1/S and store to global memory - if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { - const float S = ss[j*TS + 0]; + for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) { + const float S = 1.0f/sf[j*TS + 0]; - for (short i = tiisg; i < DV4; i += NW) { - dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S; - } + device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; + + for (short i = tiisg; i < DV4; i += NW) { + dst4[i] = (float4) so4[j*DV4 + i]*S; } } } @@ -3529,12 +3713,22 @@ kernel void kernel_flash_attn_ext( // template to be able to explore different combinations // #define FA_TYPES \ - half, half4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - half, half4x4, simdgroup_half8x8, \ - float, simdgroup_float8x8, \ - float, simdgroup_float8x8, \ - half, half4, simdgroup_half8x8 + float, float4, simdgroup_float8x8, \ + half, half4x4, simdgroup_half8x8, \ + half, half4x4, simdgroup_half8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 + //float, float4, simdgroup_float8x8 + +#define FA_TYPES_BF \ + bfloat, bfloat4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + bfloat, bfloat4x4, simdgroup_bfloat8x8, \ + float, simdgroup_float8x8, \ + float, simdgroup_float8x8, \ + half, half4, simdgroup_half8x8 + //float, float4, simdgroup_float8x8 typedef decltype(kernel_flash_attn_ext) flash_attn_ext_t; @@ -3546,16 +3740,18 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; -template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3566,6 +3762,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3575,6 +3772,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3584,6 +3782,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3593,6 +3792,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -3602,8 +3802,10 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES +#undef FA_TYPES_BF template< typename q4_t, // query types in shared memory @@ -3616,7 +3818,7 @@ template< typename kd4_t, // key type in device memory short nl_k, void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), - typename vd4_t, // key type in device memory + typename vd4_t, // value type in device memory short nl_v, void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size @@ -3650,12 +3852,12 @@ kernel void kernel_flash_attn_ext_vec( const short T = DK + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t - threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask - threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t + threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask + threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results // store the result for all queries in local memory (the O matrix from the paper) o4_t lo[DV4/NL]; @@ -3685,7 +3887,7 @@ kernel void kernel_flash_attn_ext_vec( { float S = 0.0f; - float M = -__FLT16_MAX__/2; + float M = -__FLT_MAX__/2; // thread indices inside the simdgroup const short tx = tiisg%NL; @@ -3727,6 +3929,11 @@ kernel void kernel_flash_attn_ext_vec( sm[tiisg] = pm[ic + tiisg]; } + // skip -INF blocks + if (simd_max(sm[tiisg]) == -INFINITY) { + continue; + } + // Q*K^T { // each simdgroup processes 1 query and NE (NW/NL) head elements @@ -3955,10 +4162,30 @@ kernel void kernel_flash_attn_ext_vec( half4, \ float, \ float, float4, \ - half4 + float4 typedef decltype(kernel_flash_attn_ext_vec) flash_attn_ext_vec_t; +template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + +template [[host_name("kernel_flash_attn_ext_vec_f16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_h96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_USE_BF16) template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; @@ -3999,6 +4226,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_USE_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + #undef FA_TYPES template @@ -6302,127 +6539,219 @@ kernel void kernel_mul_mm( } } -// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids -// TODO: this kernel needs to be reimplemented from scratch for better performance -template -void kernel_mul_mm_id_impl( - int32_t ne00, - int32_t ne02, - uint64_t nb01, - uint64_t nb02, - int32_t ne11, - int32_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - int32_t ne0, - int32_t ne1, - int64_t ne0ne1, - device const char * src0, - device const char * src1, - threadgroup ushort2 * rowids, - device char * dst, - threadgroup char * shmem, +template +kernel void kernel_mul_mm_id_map0( + constant ggml_metal_kargs_mul_mm_id_map0 & args, + device const char * src1, + device const char * src2, + device char * hsrc1, + device char * htpe, + device char * hids, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int ide = tgpig[0]; // expert id + + int n_all = 0; + + device int32_t * ids_i32 = (device int32_t *) (hids); + + for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens + device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21); + + for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used + if (src2_i32[i20] != ide) { + continue; + } + + device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11); + device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11); + + for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) { + hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]); + } + + if (tpitg.x == 0) { + ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all; + } + + ++n_all; + } + } + + if (tpitg.x == 0) { + device int32_t * tpe_i32 = (device int32_t *) (htpe); + tpe_i32[ide] = n_all; + } +} + +typedef decltype(kernel_mul_mm_id_map0) kernel_mul_mm_id_map0_t; + +template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0; + +template +kernel void kernel_mul_mm_id_map1( + constant ggml_metal_kargs_mul_mm_id_map1 & args, + device const char * hdst, + device const char * hids, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i20 = tgpig[0]; // used expert + const int i21 = tgpig[1]; // token + + device const int32_t * ids_i32 = (device const int32_t *) (hids); + device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2); + + const int id = ids_i32[i21*args.ne20 + i20]; + + const int ide = id / args.neh1; + const int idt = id % args.neh1; + + device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2); + + for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) { + dst_f32x4[i0] = hdst_f32x4[i0]; + } +} + +typedef decltype(kernel_mul_mm_id_map1) kernel_mul_mm_id_map1_t; + +template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1; + +template +kernel void kernel_mul_mm_id( + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0, + device const char * src1, + device const char * tpe, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shmem); - threadgroup float * sb = (threadgroup float *)(shmem + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup half * sb = (threadgroup half *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; + const int im = tgpig.z; - if (r1*BLOCK_SIZE_N >= ne1) return; + device const int32_t * tpe_i32 = (device const int32_t *) (tpe); + + const int neh1 = tpe_i32[im]; + + if (r1*BLOCK_SIZE_N >= neh1) { + return; + } // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + const short n_cols = ( neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? ( neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix - short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; - short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; - simdgroup_float8x8 mb[2]; + simdgroup_T8x8 ma[4]; + simdgroup_half8x8 mb[2]; simdgroup_float8x8 mc[8]; - for (int i = 0; i < 8; i++){ + + for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } + short il = (tiitg % THREAD_PER_ROW); - ushort offset1 = il/nl; + const int i12 = im%args.neh12; + const int i13 = im/args.neh12; - threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col]; + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1; - device const float * y = (device const float *)(src1 - + nb12 * id[1] - + nb11 * (id[0] % ne11) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + device const block_q * x = (device const block_q *)(src0 + + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + device const half * y = (device const half *)(src1 + + args.nbh13*i13 + + args.nbh12*i12 + + args.nbh11*(r1*BLOCK_SIZE_N + thread_col) + + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory - half4x4 temp_a; + T4x4 temp_a; dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < 16; i++) { - *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + #pragma unroll(16) + for (short i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \ + + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \ + + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y); il = (il + 2 < nl) ? il + 2 : il % 2; - x = (il < 2) ? x + (2+nl-1)/nl : x; + x = (il < 2) ? x + (2 + nl - 1)/nl : x; y += BLOCK_SIZE_K; threadgroup_barrier(mem_flags::mem_threadgroup); // load matrices from threadgroup memory and conduct outer products - threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); - threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + threadgroup const T * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); + threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); - #pragma unroll(BLOCK_SIZE_K/8) - for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { #pragma unroll(4) - for (int i = 0; i < 4; i++) { + for (short i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) - for (int i = 0; i < 2; i++) { + for (short i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } - lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; - lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; - #pragma unroll(8) - for (int i = 0; i < 8; i++){ + for (short i = 0; i < 8; i++){ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } + + lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; + lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; } } - { + if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32*(sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0; + + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0); + } + } else { + // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) \ - + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; - for (int i = 0; i < 8; i++) { - simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; + for (short i = 0; i < 8; i++) { + simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; - - device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0; device float4 * D4 = (device float4 *) D; threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); @@ -6442,66 +6771,6 @@ void kernel_mul_mm_id_impl( } } -template -kernel void kernel_mul_mm_id( - constant ggml_metal_kargs_mul_mm_id & args, - device const char * src0s, - device const char * src1, - device char * dst, - device const char * ids, - threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { - - const int32_t i02 = tgpig.z; - - tgpig.z = 0; - - device const char * src0 = src0s + i02*args.nb02; - - // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); - - // TODO: parallelize this loop - int32_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { - for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; - if (id == i02) { - if (tiitg == 0) { - rowids[_ne1] = ushort2(ii0, ii1); - } - _ne1++; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - kernel_mul_mm_id_impl( - args.ne00, - args.ne02, - args.nb01, - args.nb02, - args.ne11, - args.ne12, - args.nb10, - args.nb11, - args.nb12, - args.ne0, - _ne1, - (int64_t)args.ne0*args.ne1, - src0, - src1, - rowids, - dst, - shmem, - tgpig, - tiitg, - sgitg); -} - #define QK_NL 16 // @@ -6542,63 +6811,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get // matrix-matrix multiplication // -typedef decltype(kernel_mul_mm) mat_mm_t; +typedef decltype(kernel_mul_mm) mul_mm_t; -template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; // // indirect matrix-matrix multiplication // -typedef decltype(kernel_mul_mm_id) mat_mm_id_t; +typedef decltype(kernel_mul_mm_id) mul_mm_id; -template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; #if defined(GGML_METAL_USE_BF16) -template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id; #endif -template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; -template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; + // // matrix-vector multiplication diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 92f05d555..971314deb 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -27,12 +27,15 @@ if (MUSAToolkit_FOUND) file(GLOB GGML_HEADERS_MUSA "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") + list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh") file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu") file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) + file(GLOB SRCS "../ggml-musa/*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") @@ -62,7 +65,9 @@ if (MUSAToolkit_FOUND) ) # TODO: do not use CUDA definitions for MUSA - target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) + if (NOT GGML_BACKEND_DL) + target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) + endif() add_compile_definitions(GGML_USE_MUSA) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) @@ -92,9 +97,10 @@ if (MUSAToolkit_FOUND) endif() if (GGML_STATIC) + # TODO: mudnn has not provided static libraries yet target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static) else() - target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas) + target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn) endif() if (GGML_CUDA_NO_VMM) diff --git a/ggml/src/ggml-musa/mudnn.cu b/ggml/src/ggml-musa/mudnn.cu new file mode 100644 index 000000000..020c1702c --- /dev/null +++ b/ggml/src/ggml-musa/mudnn.cu @@ -0,0 +1,112 @@ +#include +#include + +#include "mudnn.cuh" + +namespace mudnn = musa::dnn; + +// Returns a human-readable error string for mudnn::Status +const char* mudnnGetErrorString(mudnn::Status err) { + switch (err) { + case mudnn::Status::SUCCESS: + return "Success"; + case mudnn::Status::INVALID_PARAMETER: + return "Invalid parameter"; + case mudnn::Status::NOT_INITIALIZED: + return "Not initialized"; + case mudnn::Status::ALLOC_FAILED: + return "Allocation failed"; + case mudnn::Status::NOT_SUPPORTED: + return "Not supported"; + case mudnn::Status::INTERNAL_ERROR: + return "Internal error"; + case mudnn::Status::ARCH_MISMATCH: + return "Architecture mismatch"; + case mudnn::Status::EXECUTION_FAILED: + return "Execution failed"; + default: + return "Unknown mudnn status"; + } +} + +// Error checking macro for MUDNN calls +#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString) + +namespace { + // Thread-safe cache for mudnn::Handle objects per device + std::unordered_map> handle_cache; + std::mutex handle_cache_mutex; + + mudnn::Handle* get_cached_handle(int device_id) { + std::lock_guard lock(handle_cache_mutex); + auto it = handle_cache.find(device_id); + if (it != handle_cache.end()) { + return it->second.get(); + } + auto handle = std::make_unique(device_id); + mudnn::Handle* handle_ptr = handle.get(); + handle_cache[device_id] = std::move(handle); + return handle_ptr; + } +} + +// Extracts dimensions and strides from a ggml_tensor +int get_ggml_dims_and_strides(const ggml_tensor* tensor, + std::vector& dims, + std::vector& strides) { + const int ndims = ggml_n_dims(tensor); + const size_t element_size = ggml_element_size(tensor); + + dims.resize(ndims); + strides.resize(ndims); + + for (int i = 0; i < ndims; ++i) { + dims[i] = tensor->ne[i]; + strides[i] = tensor->nb[i] / static_cast(element_size); + } + return ndims; +} + +// Converts ggml_type to mudnn::Tensor::Type +mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return mudnn::Tensor::Type::FLOAT; + case GGML_TYPE_F16: + return mudnn::Tensor::Type::HALF; + + // TODO: Add support for other types + + default: + MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED); + } + + return mudnn::Tensor::Type::FLOAT; // Default fallback +} + +// Asynchronous memory copy using mudnn::Unary::IDENTITY +musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) { + mudnn::Tensor tensor_dst, tensor_src; + + MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type))); + MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type))); + + std::vector dims, strides; + const int ndims = get_ggml_dims_and_strides(src, dims, strides); + + MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data())); + MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data())); + MUDNN_CHECK(tensor_dst.SetAddr(dst->data)); + MUDNN_CHECK(tensor_src.SetAddr(src->data)); + + mudnn::Unary op; + MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY)); + MUDNN_CHECK(op.SetAlpha(0.0f)); + MUDNN_CHECK(op.SetBeta(0.0f)); + + mudnn::Handle* handle = get_cached_handle(ctx.device); + MUDNN_CHECK(handle->SetStream(ctx.stream())); + MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src)); + + return musaSuccess; +} diff --git a/ggml/src/ggml-musa/mudnn.cuh b/ggml/src/ggml-musa/mudnn.cuh new file mode 100644 index 000000000..a63be5755 --- /dev/null +++ b/ggml/src/ggml-musa/mudnn.cuh @@ -0,0 +1,12 @@ +#pragma once + +#include "../include/ggml.h" +#include "../ggml-cuda/common.cuh" + +// Asynchronously copies data from src tensor to dst tensor using the provided context. +// Returns a musaError_t indicating success or failure. +musaError_t mudnnMemcpyAsync( + ggml_backend_cuda_context &ctx, + const ggml_tensor *dst, + const ggml_tensor *src +); diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 624cb1b9d..0e2a41964 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -54,16 +54,54 @@ function(ggml_opencl_add_kernel KNAME) endfunction() set(GGML_OPENCL_KERNELS - ggml-opencl - ggml-opencl_mm - ggml-opencl_cvt - ggml-opencl_gemv_noshuffle - ggml-opencl_gemv_noshuffle_general - ggml-opencl_mul_mat_Ab_Bi_8x4 - ggml-opencl_transpose_16 - ggml-opencl_transpose_32 - ggml-opencl_transpose_32_16 - ggml-opencl_im2col + add + argsort + clamp + cpy + cvt + diag_mask_inf + div + gelu + gemv_noshuffle_general + gemv_noshuffle + get_rows + group_norm + im2col_f32 + im2col_f16 + mul_mat_Ab_Bi_8x4 + mul_mv_f16_f16 + mul_mv_f16_f32_1row + mul_mv_f16_f32_l4 + mul_mv_f16_f32 + mul_mv_f32_f32 + mul_mv_q4_0_f32 + mul_mv_q4_0_f32_v + mul_mv_q4_0_f32_8x_flat + mul_mv_q4_0_f32_1d_8x_flat + mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q6_k + mul_mv_id_q4_0_f32_8x_flat + mul + norm + relu + rms_norm + rope + scale + sigmoid + silu + softmax_4_f32 + softmax_4_f16 + softmax_f32 + softmax_f16 + sub + sum_rows + transpose + concat + tsembd + upscale + tanh + pad + repeat ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b8b5cbd3e..628e574f0 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #undef MIN #undef MAX @@ -64,11 +65,42 @@ enum ADRENO_GPU_GEN { X1E, }; +enum ADRENO_CL_COMPILER_TYPE { + E031, + DX, +}; + struct ggml_cl_version { cl_uint major = 0; cl_uint minor = 0; }; + +struct ggml_cl_compiler_version { + ADRENO_CL_COMPILER_TYPE type; + int major = -1; + int minor = -1; + int patch = -1; + + bool same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major == x && minor == y && patch == z && type == t; + } + bool newer_than(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return major*10000 + minor*100 + patch > x*10000 + y*100 + z && type == t; + } + bool newer_than_or_same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const { + return same(t, x, y, z) || newer_than(t, x, y, z); + } +}; + +static size_t align_to(size_t value, size_t to_alignment) { + GGML_ASSERT(to_alignment && "Invalid alignment (must be non-zero)"); + GGML_ASSERT((to_alignment & (to_alignment - 1)) == 0 && "to_alignment must be power-of-two"); + + return ((value + to_alignment - 1) / to_alignment) * to_alignment; +} + + // Parses a version string of form "XX.YY ". On an error returns ggml_cl_version with all zeroes. static ggml_cl_version parse_cl_version(std::string_view str) { size_t major_str_begin = 0; @@ -173,33 +205,51 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { return ADRENO_GPU_GEN::ADRENO_UNKNOWN; } -static int get_adreno_cl_compiler_version(const char *driver_version) { +static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *driver_version) { std::string driver_ver_str(driver_version); + ADRENO_CL_COMPILER_TYPE type = ADRENO_CL_COMPILER_TYPE::E031; size_t compiler_ver_pos = driver_ver_str.find("E031"); size_t compiler_ver_len = 13; - size_t compiler_ver_offset = 5; + size_t compiler_major_offset = 5; + size_t compiler_minor_offset = 8; + size_t compiler_patch_offset = 11; if (compiler_ver_pos == std::string::npos) { compiler_ver_pos = driver_ver_str.find("DX"); if (compiler_ver_pos == std::string::npos) { - return -1; + return {}; } + type = ADRENO_CL_COMPILER_TYPE::DX; compiler_ver_len = 11; - compiler_ver_offset = 3; + compiler_major_offset = 3; } std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len); - std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2); - return std::atoi(major_ver_str.c_str()); + int major = std::atoi(compiler_ver_str.substr(compiler_major_offset, 2).c_str()); + int minor = std::atoi(compiler_ver_str.substr(compiler_minor_offset, 2).c_str()); + int patch = std::atoi(compiler_ver_str.substr(compiler_patch_offset, 2).c_str()); + return { type, major, minor, patch }; } +struct ggml_backend_opencl_context; + // backend device context struct ggml_backend_opencl_device_context { cl_platform_id platform; std::string platform_name; - cl_device_id device; - std::string device_name; + cl_device_id device; + std::string device_name; + cl_device_type device_type; + std::string device_version; + + // Initialized by ggml_cl2_init(). + ggml_backend_opencl_context * backend_ctx = nullptr; + + // Initialized by ggml_backend_opencl_device_get_buffer_type() + ggml_backend_buffer_type buffer_type; + + cl_context context = nullptr; }; // backend context @@ -215,27 +265,78 @@ struct ggml_backend_opencl_context { cl_int alignment; size_t max_alloc_size; bool fp16_support; + bool has_vector_subgroup_broadcast; + ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; + cl_bool non_uniform_workgroups; + cl_context context; cl_command_queue queue; - cl_program program; - cl_program program_1; - cl_program program_2; - cl_program program_im2col; + cl_program program_add; + cl_program program_clamp; + cl_program program_cpy; + cl_program program_cvt; + cl_program program_diag_mask_inf; + cl_program program_gelu; + cl_program program_gemv_noshuffle_general; + cl_program program_gemv_noshuffle; + cl_program program_get_rows; + cl_program program_im2col_f16; + cl_program program_im2col_f32; + cl_program program_mul_mat_Ab_Bi_8x4; + cl_program program_mul_mv_q4_0_f32; + cl_program program_mul_mv_q4_0_f32_v; + cl_program program_mul_mv_q4_0_f32_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_8x_flat; + cl_program program_mul_mv_q4_0_f32_1d_16x_flat; + cl_program program_mul_mv_q6_K; + cl_program program_mul_mv_f16_f16; + cl_program program_mul_mv_f16_f32_1row; + cl_program program_mul_mv_f16_f32_l4; + cl_program program_mul_mv_f16_f32; + cl_program program_mul_mv_f32_f32; + cl_program program_mul; + cl_program program_div; + cl_program program_sub; + cl_program program_norm; + cl_program program_relu; + cl_program program_rms_norm; + cl_program program_group_norm; + cl_program program_rope; + cl_program program_scale; + cl_program program_silu; + cl_program program_sigmoid; + cl_program program_softmax_f32; + cl_program program_softmax_f16; + cl_program program_softmax_4_f32; + cl_program program_softmax_4_f16; + cl_program program_argsort_f32_i32; + cl_program program_sum_rows_f32; + cl_program program_repeat; + cl_program program_pad; + cl_program program_tanh; + cl_program program_upscale; + cl_program program_concat; + cl_program program_tsembd; + cl_program program_mul_mv_id_q4_0_f32_8x_flat; cl_kernel kernel_add, kernel_add_row; cl_kernel kernel_mul, kernel_mul_row; + cl_kernel kernel_div, kernel_div_row; + cl_kernel kernel_sub, kernel_sub_row; cl_kernel kernel_scale; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; + cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; cl_kernel kernel_clamp; cl_kernel kernel_norm; cl_kernel kernel_rms_norm; + cl_kernel kernel_group_norm; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; @@ -249,19 +350,29 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; - cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat; + cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; - cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0, - kernel_mul_mat_q4_0_f32_flat_img_v0; + cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; + cl_kernel kernel_argsort_f32_i32; + cl_kernel kernel_sum_rows_f32; + cl_kernel kernel_repeat; + cl_kernel kernel_pad; + cl_kernel kernel_tanh_f32_nd; + cl_kernel kernel_tanh_f16_nd; + cl_kernel kernel_upscale; + cl_kernel kernel_upscale_bilinear; + cl_kernel kernel_concat_f32_contiguous; + cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_timestep_embedding; + cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Transpose kernels - cl_program program_transpose_32; - cl_program program_transpose_32_16; - cl_program program_transpose_16; + cl_program program_transpose; + cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; @@ -286,15 +397,8 @@ struct ggml_backend_opencl_context { #endif // GGML_OPENCL_USE_ADRENO_KERNELS }; -static ggml_backend_device g_ggml_backend_opencl_device; -static ggml_backend_opencl_device_context g_ggml_ctx_dev_main { - /*.platform =*/ nullptr, - /*.platform_nane =*/ "", - /*.device =*/ nullptr, - /*.device_name =*/ "", -}; - -static int ggml_backend_opencl_n_devices = 0; +// All registered devices with a default device in the front. +static std::vector g_ggml_backend_opencl_devices; // Profiling #ifdef GGML_OPENCL_PROFILING @@ -374,25 +478,953 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { - static bool initialized = false; - static ggml_backend_opencl_context *backend_ctx = nullptr; +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { + cl_int err; - if (initialized) { - return backend_ctx; + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); + + // add + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "add.cl.h" + }; +#else + const std::string kernel_src = read_file("add.cl"); +#endif + backend_ctx->program_add = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err)); + CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err)); + GGML_LOG_CONT("."); } - ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context; - GGML_ASSERT(dev_ctx); - GGML_ASSERT(dev_ctx->platform == nullptr); - GGML_ASSERT(dev_ctx->device == nullptr); - GGML_ASSERT(backend_ctx == nullptr); + // clamp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "clamp.cl.h" + }; +#else + const std::string kernel_src = read_file("clamp.cl"); +#endif + backend_ctx->program_clamp = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - initialized = true; - backend_ctx = new ggml_backend_opencl_context(); - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program_clamp, "kernel_clamp", &err), err)); + GGML_LOG_CONT("."); + } - cl_int err; + // cpy + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cpy.cl.h" + }; +#else + const std::string kernel_src = read_file("cpy.cl"); +#endif + backend_ctx->program_cpy = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // cvt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cvt.cl.h" + }; +#else + const std::string kernel_src = read_file("cvt.cl"); +#endif + backend_ctx->program_cvt = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // diag_mask_inf + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag_mask_inf.cl.h" + }; +#else + const std::string kernel_src = read_file("diag_mask_inf.cl"); +#endif + backend_ctx->program_diag_mask_inf = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf_8", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf", &err), err)); + GGML_LOG_CONT("."); + } + + // gelu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gelu.cl.h" + }; +#else + const std::string kernel_src = read_file("gelu.cl"); +#endif + backend_ctx->program_gelu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err)); + GGML_LOG_CONT("."); + } + + // get_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "get_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("get_rows.cl"); +#endif + backend_ctx->program_get_rows = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f32.cl"); +#endif + backend_ctx->program_im2col_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col_f32, "kernel_im2col_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f16.cl"); +#endif + backend_ctx->program_im2col_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col_f16, "kernel_im2col_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32, "kernel_mul_mat_q4_0_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_v + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_v.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_v.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_v = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_v, "kernel_mul_mat_q4_0_f32_v", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_8x_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_8x_flat, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_1d_8x_flat + // This kernel does not compiler on Adreno cl compiler 38.01. Skip it for + // those compiler versions since it is anyway not used for Adreno. + if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_1d_16x_flat + // This kernel does not compiler on Adreno cl compiler 38.01. Skip it for + // those compiler versions since it is anyway not used for Adreno. + if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_16x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_16x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q6_k + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q6_k.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q6_k.cl"); +#endif + backend_ctx->program_mul_mv_q6_K = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, "kernel_mul_mv_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f16.cl"); +#endif + backend_ctx->program_mul_mv_f16_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, "kernel_mul_mat_f16_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_1row + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_1row.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_1row.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_1row = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, "kernel_mul_mat_f16_f32_1row", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_l4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_l4.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_l4.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_l4 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, "kernel_mul_mat_f16_f32_l4", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, "kernel_mul_mat_f16_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f32_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f32_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f32_f32.cl"); +#endif + backend_ctx->program_mul_mv_f32_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, "kernel_mul_mat_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul.cl.h" + }; +#else + const std::string kernel_src = read_file("mul.cl"); +#endif + backend_ctx->program_mul = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err)); + GGML_LOG_CONT("."); + } + + // norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "norm.cl.h" + }; +#else + const std::string kernel_src = read_file("norm.cl"); +#endif + backend_ctx->program_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // relu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "relu.cl.h" + }; +#else + const std::string kernel_src = read_file("relu.cl"); +#endif + backend_ctx->program_relu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, "kernel_relu", &err), err)); + GGML_LOG_CONT("."); + } + + // rms_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rms_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("rms_norm.cl"); +#endif + backend_ctx->program_rms_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // rope + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rope.cl.h" + }; +#else + const std::string kernel_src = read_file("rope.cl"); +#endif + backend_ctx->program_rope = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // scale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "scale.cl.h" + }; +#else + const std::string kernel_src = read_file("scale.cl"); +#endif + backend_ctx->program_scale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + GGML_LOG_CONT("."); + } + + // silu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "silu.cl.h" + }; +#else + const std::string kernel_src = read_file("silu.cl"); +#endif + backend_ctx->program_silu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program_silu, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, "kernel_silu_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f32.cl"); +#endif + backend_ctx->program_softmax_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program_softmax_f32, "kernel_soft_max", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f16.cl"); +#endif + backend_ctx->program_softmax_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, "kernel_soft_max_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f32.cl"); +#endif + backend_ctx->program_softmax_4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, "kernel_soft_max_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f16.cl"); +#endif + backend_ctx->program_softmax_4_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, "kernel_soft_max_4_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // argsort + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + GGML_LOG_CONT("."); + } + + // div + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "div.cl.h" + }; +#else + const std::string kernel_src = read_file("div.cl"); +#endif + backend_ctx->program_div = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err)); + CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err)); + GGML_LOG_CONT("."); + } + + // sub + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sub.cl.h" + }; +#else + const std::string kernel_src = read_file("sub.cl"); +#endif + backend_ctx->program_sub = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err)); + CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err)); + GGML_LOG_CONT("."); + } + + // sum_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sum_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("sum_rows.cl"); +#endif + backend_ctx->program_sum_rows_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // sigmoid + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sigmoid.cl.h" + }; +#else + const std::string kernel_src = read_file("sigmoid.cl"); +#endif + backend_ctx->program_sigmoid = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sigmoid_f32 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sigmoid_f16 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // group_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "group_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("group_norm.cl"); +#endif + backend_ctx->program_group_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // repeat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "repeat.cl.h" + }; +#else + const std::string kernel_src = read_file("repeat.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_repeat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n"); + backend_ctx->program_repeat = nullptr; + backend_ctx->kernel_repeat = nullptr; + } + } + + // pad + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "pad.cl.h" + }; +#else + const std::string kernel_src = read_file("pad.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_pad = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_pad = clCreateKernel(backend_ctx->program_pad, "kernel_pad", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: pad kernel source not found or empty. Pad operations will not be available.\n"); + backend_ctx->program_pad = nullptr; + backend_ctx->kernel_pad = nullptr; + } + } + + // tanh + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tanh.cl.h" + }; +#else + const std::string kernel_src = read_file("tanh.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_tanh = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n"); + backend_ctx->program_tanh = nullptr; + backend_ctx->kernel_tanh_f32_nd = nullptr; + backend_ctx->kernel_tanh_f16_nd = nullptr; + } + } + + // upscale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "upscale.cl.h" + }; +#else + const std::string kernel_src = read_file("upscale.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_upscale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err)); + if (backend_ctx->program_upscale) { + cl_int err_bilinear; + backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); + if (err_bilinear != CL_SUCCESS) { + GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear); + backend_ctx->kernel_upscale_bilinear = nullptr; + } + } else { + backend_ctx->kernel_upscale_bilinear = nullptr; + } + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: upscale kernel source not found or empty. Upscale operations will not be available.\n"); + backend_ctx->program_upscale = nullptr; + backend_ctx->kernel_upscale = nullptr; + backend_ctx->kernel_upscale_bilinear = nullptr; + } + } + + // concat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "concat.cl.h" + }; +#else + + const std::string kernel_src = read_file("concat.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_concat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err)); + CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n"); + backend_ctx->program_concat = nullptr; + backend_ctx->kernel_concat_f32_contiguous = nullptr; + backend_ctx->kernel_concat_f32_non_contiguous = nullptr; + } + } + + // timestep_embedding + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tsembd.cl.h" + }; +#else + + const std::string kernel_src = read_file("tsembd.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_tsembd = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_timestep_embedding = clCreateKernel(backend_ctx->program_tsembd, "kernel_timestep_embedding", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: timestep_embedding kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_tsembd = nullptr; + backend_ctx->kernel_timestep_embedding = nullptr; + } + } + + // mul_mv_id_q4_0_f32_8x_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_q4_0_f32_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_q4_0_f32_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat, "kernel_mul_mv_id_q4_0_f32_8x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // Adreno kernels +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // transpose + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "transpose.cl.h" + }; +#else + const std::string kernel_src = read_file("transpose.cl"); +#endif + backend_ctx->program_transpose = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_general + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_general.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); +#endif + + backend_ctx->program_CL_gemv_general = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle + { + // Gemv 2048, 16384 + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "gemv_noshuffle.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); +#endif + + backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 5504, 44032 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 16000, 128000 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mat_Ab_Bi_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "mul_mat_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_CONT("\n"); +} + +// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +// XXX static bool initialized = false; +// XXX static ggml_backend_opencl_context *backend_ctx = nullptr; + +static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); + +namespace /* anonymous */ { +extern struct ggml_backend_device_i ggml_backend_opencl_device_i; +} + +// Look for available and suitable devices. +static std::vector ggml_opencl_probe_devices(ggml_backend_reg * reg) { + std::vector found_devices; #ifdef GGML_OPENCL_PROFILING GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n"); @@ -425,11 +1457,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { struct cl_device devices[NDEV]; unsigned n_devices = 0; struct cl_device * default_device = NULL; + unsigned default_platform_number = 0; cl_platform_id platform_ids[NPLAT]; if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); - return backend_ctx; + return found_devices; } for (unsigned i = 0; i < n_platforms; i++) { @@ -464,19 +1497,22 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { } if (default_device == NULL && p->default_device != NULL) { - default_device = p->default_device; + default_device = p->default_device; + default_platform_number = i; } } if (n_devices == 0) { GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n"); - return backend_ctx; + return found_devices; } - char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); - char * user_device_string = getenv("GGML_OPENCL_DEVICE"); - int user_platform_number = -1; - int user_device_number = -1; + char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); + char * user_device_string = getenv("GGML_OPENCL_DEVICE"); + int user_platform_number = -1; + int user_device_number = -1; + cl_device * candidate_devices = nullptr; + unsigned n_candidate_devices = 0; unsigned n; if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) { @@ -491,12 +1527,11 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number); exit(1); } - default_device = &platform->devices[user_device_number]; + default_device = &platform->devices[user_device_number]; + candidate_devices = platform->devices; + n_candidate_devices = platform->n_devices; } else { - - struct cl_device * selected_devices = devices; - unsigned n_selected_devices = n_devices; - + // Choose a platform by matching a substring. if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) { for (unsigned i = 0; i < n_platforms; i++) { struct cl_platform * p = &platforms[i]; @@ -511,20 +1546,20 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { exit(1); } } - if (user_platform_number != -1) { - struct cl_platform * p = &platforms[user_platform_number]; - selected_devices = p->devices; - n_selected_devices = p->n_devices; - default_device = p->default_device; - if (n_selected_devices == 0) { - GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); - exit(1); - } + + int platform_idx = user_platform_number != -1 ? user_platform_number : default_platform_number; + struct cl_platform * p = &platforms[platform_idx]; + candidate_devices = p->devices; + n_candidate_devices = p->n_devices; + default_device = p->default_device; + if (n_candidate_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); + exit(1); } if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) { - for (unsigned i = 0; i < n_selected_devices; i++) { - struct cl_device * d = &selected_devices[i]; + for (unsigned i = 0; i < n_candidate_devices; i++) { + struct cl_device * d = &candidate_devices[i]; if (strstr(d->name, user_device_string) != NULL) { user_device_number = d->number; break; @@ -536,71 +1571,145 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { } } if (user_device_number != -1) { - selected_devices = &devices[user_device_number]; - n_selected_devices = 1; - default_device = &selected_devices[0]; + candidate_devices = &devices[user_device_number]; + n_candidate_devices = 1; + default_device = &candidate_devices[0]; } - GGML_ASSERT(n_selected_devices > 0); + GGML_ASSERT(n_candidate_devices > 0); if (default_device == NULL) { - default_device = &selected_devices[0]; + default_device = &candidate_devices[0]; } } - GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name); - GGML_LOG_INFO("ggml_opencl: selecting device: '%s (%s)'\n", default_device->name, default_device->version); - if (default_device->type != CL_DEVICE_TYPE_GPU) { - GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name); + GGML_ASSERT(n_candidate_devices != 0 && candidate_devices); + + // Put the default device in front. + for (unsigned i = 1; i < n_candidate_devices; i++) { + if (&candidate_devices[i] == default_device) { + std::swap(candidate_devices[0], candidate_devices[i]); + default_device = &candidate_devices[0]; + break; + } } - dev_ctx->platform = default_device->platform->id; - dev_ctx->device = default_device->id; - backend_ctx->device = default_device->id; + GGML_LOG_INFO("ggml_opencl: selected platform: '%s'\n", default_device->platform->name); - if (strstr(default_device->name, "Adreno") || - strstr(default_device->name, "Qualcomm") || - strstr(default_device->version, "Adreno")) { + std::vector device_ids; + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + device_ids.push_back(dev->id); + } + + cl_int err; + cl_context shared_context; + cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 }; + + CL_CHECK( + (shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err)); + + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + GGML_LOG_INFO("\nggml_opencl: device: '%s (%s)'\n", dev->name, dev->version); + + auto dev_ctx = std::unique_ptr(new ggml_backend_opencl_device_context{ + /*.platform =*/dev->platform->id, + /*.platform_nane =*/dev->platform->name, + /*.device =*/dev->id, + /*.device_name =*/dev->name, + /*.device_type =*/dev->type, + /*.device_version =*/dev->version, + /*.backend_ctx =*/nullptr, + /*.buffer_type =*/{}, + /*.context =*/shared_context, + }); + + found_devices.push_back(ggml_backend_device{ + /* .iface = */ ggml_backend_opencl_device_i, + /* .reg = */ reg, + /* .context = */ dev_ctx.get(), + }); + + if (!ggml_cl2_init(&found_devices.back())) { + found_devices.pop_back(); + GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); + continue; + } + + dev_ctx.release(); + } + + if (found_devices.size()) { + auto * dev_ctx = static_cast(found_devices.front().context); + GGML_LOG_INFO("ggml_opencl: default device: '%s (%s)'\n", dev_ctx->device_name.c_str(), + dev_ctx->device_version.c_str()); + + if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) { + GGML_LOG_WARN("ggml_opencl: warning, the default device is not a GPU: '%s'.\n", + dev_ctx->device_name.c_str()); + } + } + + return found_devices; +} + +// Initialize device if it is supported (returns nullptr if it is not). +static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (dev_ctx->backend_ctx) { + return dev_ctx->backend_ctx; + } + + auto backend_ctx = std::make_unique(); + backend_ctx->device = dev_ctx->device; + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { backend_ctx->gpu_family = GPU_FAMILY::ADRENO; // Usually device version contains the detailed device name - backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->version); + backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name); + backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); } // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; - } else if (strstr(default_device->name, "Intel")) { + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { backend_ctx->gpu_family = GPU_FAMILY::INTEL; } else { - GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name); + GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - return backend_ctx; + return nullptr; } #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); - return backend_ctx; + return nullptr; } #endif // Populate backend device name - dev_ctx->platform_name = default_device->platform->name; - dev_ctx->device_name = default_device->name; - backend_ctx->device_name = default_device->name; + backend_ctx->device_name = dev_ctx->device_name; // A local ref of cl_device_id for convenience cl_device_id device = backend_ctx->device; - ggml_cl_version platform_version = get_opencl_platform_version(default_device->platform->id); + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); // Check device OpenCL version, OpenCL 2.0 or above is required ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); if (opencl_c_version.major < 2) { GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); - return backend_ctx; + return nullptr; } // Check driver version @@ -612,11 +1721,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); backend_ctx->driver_version = driver_version; - int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); - bool has_vector_subgroup_broadcast = - adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17; + backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + backend_ctx->has_vector_subgroup_broadcast = + backend_ctx->adreno_cl_compiler_version.major >= 47 || + backend_ctx->adreno_cl_compiler_version.major == 17; GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - has_vector_subgroup_broadcast ? "true" : "false"); + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); size_t ext_str_size; clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); @@ -630,7 +1740,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // fp16 is required if (!backend_ctx->fp16_support) { GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); - return backend_ctx; + return nullptr; } // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes @@ -639,7 +1749,7 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { strstr(ext_buffer, "cl_intel_subgroups") == NULL) { GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " "(note that subgroups is an optional feature in OpenCL 3.0)\n"); - return backend_ctx; + return nullptr; } cl_uint base_align_in_bits; @@ -663,6 +1773,15 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + if (opencl_c_version.major >= 3) { + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), + &backend_ctx->non_uniform_workgroups, 0)); + } else { + GGML_ASSERT(opencl_c_version.major == 2); + // Non-uniform workgroup sizes is mandatory feature in v2.x. + backend_ctx->non_uniform_workgroups = true; + } + // Print out configurations #ifdef GGML_OPENCL_SOA_Q GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); @@ -672,14 +1791,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS - cl_context_properties properties[] = { - (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0 - }; - - CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err)); + cl_int err; // A local ref of cl_context for convenience - cl_context context = backend_ctx->context; + cl_context context = backend_ctx->context = dev_ctx->context; //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err), // (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err : @@ -691,247 +1806,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "ggml-opencl.cl.h" - }; -#else - const std::string kernel_src = read_file("ggml-opencl.cl"); -#endif + // Load kernels + load_cl_kernels(backend_ctx.get(), opencl_c_version); - auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); - - std::string compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable -cl-unsafe-math-optimizations" - " -cl-finite-math-only -cl-fast-relaxed-math"; - backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts); - - // Non matmul kernels. - CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program, "kernel_add", &err), err)); - CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err)); - CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err)); - CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err)); - CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); - CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); - CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); - CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err)); - - // Matmul kernels. - CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err)); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); - - // Load additional mulmat kernels. -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_1 { - #include "ggml-opencl_mm.cl.h" - }; -#else - const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl"); -#endif - backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err)); - - // Load additional data conversion kernels. -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_2 { - #include "ggml-opencl_cvt.cl.h" - }; -#else - const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl"); -#endif - backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); - - // im2col kernels -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_im2col { - #include "ggml-opencl_im2col.cl.h" - }; -#else - const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); -#endif - backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); - - // Kernels for Adreno #ifdef GGML_OPENCL_USE_ADRENO_KERNELS -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_src { - #include "ggml-opencl_transpose_32.cl.h" - }; -#else - const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl"); -#endif - backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_16_src { - #include "ggml-opencl_transpose_32_16.cl.h" - }; -#else - const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl"); -#endif - backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_16_src { - #include "ggml-opencl_transpose_16.cl.h" - }; -#else - const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl"); -#endif - backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err)); - - // Gemv general - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv_general { - #include "ggml-opencl_gemv_noshuffle_general.cl.h" - }; -#else - const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl"); -#endif - - backend_ctx->program_CL_gemv_general = build_program_from_source( - context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv { - #include "ggml-opencl_gemv_noshuffle.cl.h" - }; -#else - const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl"); -#endif - - backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 5504, 44032 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=5504 " - " -DBLOCK_STRIDE_A=44032 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 16000, 128000 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=16000 " - " -DBLOCK_STRIDE_A=128000 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemm -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemm { - #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h" - }; -#else - const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl"); -#endif - backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); - - // TODO: fixme: these sizes are hardcoded for now. - // they should be allocated based on the model's size - // and the device's max alloc size // Allocate intermediate buffers and images size_t required_A_q_d_bytes = 311164928; size_t required_A_s_d_bytes = 38895616; @@ -959,10 +1837,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err)); #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // For now we support a single devices - ggml_backend_opencl_n_devices = 1; - - return backend_ctx; + dev_ctx->backend_ctx = backend_ctx.release(); + return dev_ctx->backend_ctx; } static void ggml_cl2_free(void) { @@ -1024,7 +1900,7 @@ static void ggml_cl2_free(void) { info.cmd_complete_duration_ns/1.e6f, info.cmd_total_duration_ns/1.e6f, info.global_size[0], info.global_size[1], info.global_size[2], - info.local_size[0], info.local_size[2], info.local_size[2], + info.local_size[0], info.local_size[1], info.local_size[2], info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]); } fclose(fperf); @@ -1164,13 +2040,54 @@ static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const g } static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { - GGML_UNUSED(backend); + auto * backend_ctx = static_cast(backend->context); + + cl_event evt; + CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, 0, nullptr, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); +} + +// Syncronizes the 'backend_ctx's device with others so that commands +// enqueued to it won't start until commands in the other devices have +// completed. +static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { + if (g_ggml_backend_opencl_devices.size() < 2) + return; // No other devices to synchronize with. + + std::vector events; + events.reserve(g_ggml_backend_opencl_devices.size()); + + for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { + auto * other_backend_ctx = ggml_cl2_init(&backend_dev); + if (backend_ctx != other_backend_ctx) { + cl_event ev; + CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); + CL_CHECK(clFlush(other_backend_ctx->queue)); + events.push_back(ev); + } + } + + CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr)); + for (auto ev : events) { + CL_CHECK(clReleaseEvent(ev)); + } +} + +static void sync_with_other_backends(ggml_backend_t backend) { + auto * backend_ctx = static_cast(backend->context); + sync_with_other_backends(backend_ctx); } static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + // NOTE: this may oversynchronize by synchronizing with + // backends/devices which don't compute 'cgraph's + // dependencies. + sync_with_other_backends(backend); + if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } @@ -1232,6 +2149,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_ADD: case GGML_OP_SCALE: case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SUB: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -1240,6 +2159,11 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU_QUICK: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_UNARY_OP_SIGMOID: + return ggml_is_contiguous(op->src[0]); + case GGML_UNARY_OP_TANH: + return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || + (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); default: return false; } @@ -1249,16 +2173,36 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_NORM: case GGML_OP_RMS_NORM: return true; + case GGML_OP_REPEAT: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded + case GGML_OP_PAD: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && + op->src[0]->ne[3] == 1 && op->ne[3] == 1; + case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONCAT: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_TIMESTEP_EMBEDDING: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); case GGML_OP_MUL_MAT: if (op->src[0]->type == GGML_TYPE_F16) { return true; } else if (op->src[0]->type == GGML_TYPE_F32) { - return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } return false; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->type == GGML_TYPE_Q4_0) { + if (op->src[1]->type == GGML_TYPE_F32) { + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } + } + return false; case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: @@ -1288,6 +2232,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } case GGML_OP_IM2COL: return true; + case GGML_OP_ARGSORT: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM_ROWS: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); default: return false; } @@ -1307,7 +2255,7 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ - /* .synchronize = */ NULL, /* ggml_backend_opencl_synchronize */ + /* .synchronize = */ ggml_backend_opencl_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, @@ -1495,8 +2443,15 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff // The optimized gemm and gemv kernels are used for large matrices without batch. // tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_tensor *tensor) { - return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 && +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } @@ -1554,15 +2509,16 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // The original tensor memory is divided into scales and quants, i.e., // we first store scales, then quants. // Create subbuffer for scales. - region.origin = extra_orig->offset + tensor->view_offs + offset; + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_d; extra->d = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + auto previous_origin = region.origin; // Create subbuffer for quants. - region.origin = extra_orig->offset + tensor->view_offs + offset + size_d; + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); region.size = size_q; extra->q = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, @@ -1574,7 +2530,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; // The optimized kernels need weights in natural order, so unshuffle. - if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; } #else @@ -1598,7 +2554,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image - if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { // <----------------------------------------------------------------------------------> // // start transpose // <----------------------------------------------------------------------------------> // @@ -1767,8 +2723,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; - // Make sure all previously submitted commands are finished. - CL_CHECK(clFinish(queue)); + // Make sure all previously submitted commands in other devices are finished. + sync_with_other_backends(backend_ctx); #ifdef GGML_OPENCL_SOA_Q // In end-to-end runs, get_tensor is usually used to get back the logits, @@ -1872,13 +2828,8 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b } static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { - // FIXME: not thread safe, device may not be initialized yet - static cl_uint alignment = -1; - if (alignment == (cl_uint)-1) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - alignment = backend_ctx->alignment; - } - return alignment; + ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); + return backend_ctx->alignment; } static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { @@ -1905,16 +2856,6 @@ static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = { /* .is_host = */ NULL, }; -ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() { - static ggml_backend_buffer_type buffer_type = { - /* .iface = */ ggml_backend_opencl_buffer_type_interface, - /* .device = */ &g_ggml_backend_opencl_device, - /* .context = */ nullptr, - }; - - return &buffer_type; -} - // // backend device // @@ -1972,9 +2913,15 @@ static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, co } static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) { - return ggml_backend_opencl_buffer_type(); + auto * dev_ctx = static_cast(dev->context); - GGML_UNUSED(dev); + dev_ctx->buffer_type = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_opencl_buffer_type_interface, + /* .device = */ dev, + /* .context = */ nullptr, + }; + + return &dev_ctx->buffer_type; } static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { @@ -1990,12 +2937,21 @@ static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const } static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name; + // Check 'dev' and 'buffer_type' are not objects belonging to this backend. + if (dev->iface.get_name != ggml_backend_opencl_device_get_name || + buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) { + return false; + } - GGML_UNUSED(dev); + // Check cl_context is the same. clEnqueue* commands may not use + // buffers from another cl_context. + ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); + return backend_ctx0->context == backend_ctx1->context; } -static struct ggml_backend_device_i ggml_backend_opencl_device_i = { +namespace /* anonymous */ { +struct ggml_backend_device_i ggml_backend_opencl_device_i = { /* .get_name = */ ggml_backend_opencl_device_get_name, /* .get_description = */ ggml_backend_opencl_device_get_description, /* .get_memory = */ ggml_backend_opencl_device_get_memory, @@ -2012,6 +2968,7 @@ static struct ggml_backend_device_i ggml_backend_opencl_device_i = { /* .event_free = */ NULL, /* .event_synchronize = */ NULL, }; +} // Backend registry @@ -2022,15 +2979,15 @@ static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) { } static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) { - return ggml_backend_opencl_n_devices; + return g_ggml_backend_opencl_devices.size(); GGML_UNUSED(reg); } static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); + GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg)); - return &g_ggml_backend_opencl_device; + return &g_ggml_backend_opencl_devices[index]; GGML_UNUSED(reg); GGML_UNUSED(index); @@ -2044,27 +3001,23 @@ static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = { }; ggml_backend_reg_t ggml_backend_opencl_reg(void) { - // TODO: make this thread-safe somehow? + static std::mutex mutex; static ggml_backend_reg reg; static bool initialized = false; + std::lock_guard lock(mutex); - if (!initialized) { - reg = ggml_backend_reg { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_opencl_reg_i, - /* .context = */ NULL, - }; - - g_ggml_backend_opencl_device = ggml_backend_device { - /* .iface = */ ggml_backend_opencl_device_i, - /* .reg = */ ®, - /* .context = */ &g_ggml_ctx_dev_main, - }; - - ggml_cl2_init(&g_ggml_backend_opencl_device); - - initialized = true; + if (initialized) { + return ® } + initialized = true; + + g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(®); + + reg = ggml_backend_reg{ + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_opencl_reg_i, + /* .context = */ NULL, + }; return ® } @@ -2438,14 +3391,19 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } else { unsigned int nth = MIN(64, ne0); @@ -2568,6 +3526,261 @@ static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); } + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_div_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_div; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + } + + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif + } +} + +static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + bool bcast_row = false; + cl_kernel kernel; + + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + kernel = backend_ctx->kernel_sub_row; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_sub; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + } + if (bcast_row) { int n = ggml_nelements(dst)/4; size_t global_work_size[] = {(size_t)n, 1, 1}; @@ -2729,14 +3942,19 @@ static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -2769,14 +3987,71 @@ static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sigmoid_f32; + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_sigmoid_f16; + } else { + GGML_ASSERT(false && "Unsupported data types for sigmoid (input and output must be both f32 or f16)"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -2816,14 +4091,19 @@ static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, cons size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -2899,8 +4179,8 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; cl_command_queue queue = backend_ctx->queue; - ggml_backend_opencl_device_context * dev_ctx = - (ggml_backend_opencl_device_context *)backend->device->context; + //ggml_backend_opencl_device_context * dev_ctx = + // (ggml_backend_opencl_device_context *)backend->device->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -2931,13 +4211,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c // Note, this kernel declares local memory in kernel args and the size // depends on subgroup size. - // Retrieve subgroup size. // Note, this requires OpenCL 2.1 and above + // For now we use fixed subgroup size to simplify support for OpenCL 2.0. size_t sgs; - CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, - CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, - sizeof(local_work_size), local_work_size, - sizeof(size_t), &sgs, NULL)); + //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + // CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + // sizeof(local_work_size), local_work_size, + // sizeof(size_t), &sgs, NULL)); + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -2965,6 +4252,595 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c #endif } +static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + int32_t n_groups = ((const int32_t *) dst->op_params)[0]; + int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups); + float eps = ((const float *) dst->op_params)[1]; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne = ne00*ne01*ne02; + + cl_kernel kernel = backend_ctx->kernel_group_norm; + + size_t sgs = 64; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &group_size)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + + size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1}; + size_t local_work_size[] = {(size_t)sgs, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0_abs = extra0->offset + src0->view_offs; + cl_ulong offsetd_abs = extrad->offset + dst->view_offs; + + cl_kernel kernel; + if (dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_nd; + } else if (dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_tanh_f16_nd; + } else { + GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh"); + } + GGML_ASSERT(kernel != nullptr); + + const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3]; + const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3]; + const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); + + size_t global_work_size[3]; + if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements + return; + } + global_work_size[0] = (size_t)ne10; + global_work_size[1] = (size_t)ne11; + global_work_size[2] = (size_t)ne12; + + size_t lws0 = 16, lws1 = 4, lws2 = 1; + if (ne10 < 16) lws0 = ne10; + if (ne11 < 4) lws1 = ne11; + if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + + while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; + while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; + while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + + + size_t local_work_size[] = {lws0, lws1, lws2}; + + size_t* local_work_size_ptr = local_work_size; + if (!backend_ctx->non_uniform_workgroups) { + if (global_work_size[0] % local_work_size[0] != 0 || + global_work_size[1] % local_work_size[1] != 0 || + global_work_size[2] % local_work_size[2] != 0) { + local_work_size_ptr = NULL; + } + } + if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; + + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(dst->type == src0->type); + + UNUSED(src1_shape_def); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_repeat == nullptr) { + GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3]; + const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3]; + + const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3]; + const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_repeat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3)); + + size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1; + size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1; + size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1; + + size_t global_work_size[] = { gws0, gws1, gws2 }; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, (size_t[3]){0,0,0}, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_pad == nullptr) { + GGML_LOG_WARN("%s: pad kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const int s_ne0 = src0->ne[0]; + const int s_ne1 = src0->ne[1]; + const int s_ne2 = src0->ne[2]; + + const int d_ne0 = dst->ne[0]; + const int d_ne1 = dst->ne[1]; + const int d_ne2 = dst->ne[2]; + + cl_kernel kernel = backend_ctx->kernel_pad; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2)); + + size_t lws0 = 64; + size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; + + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 }; + size_t local_work_size[] = { lws0, 1, 1 }; + + size_t * local_work_size_ptr = local_work_size; + if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr ? local_work_size : (size_t[3]){0,0,0}, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); + cl_kernel kernel = nullptr; + + if (mode == GGML_SCALE_MODE_NEAREST) { + kernel = backend_ctx->kernel_upscale; + if (kernel == nullptr) { + GGML_LOG_WARN("%s: nearest upscale kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + kernel = backend_ctx->kernel_upscale_bilinear; + if (kernel == nullptr) { + GGML_LOG_WARN("%s: bilinear upscale kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + } else { + GGML_LOG_WARN("%s: unsupported upscale mode %d, skipping OpenCL execution.\n", __func__, mode); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne00_src = src0->ne[0]; + const int ne01_src = src0->ne[1]; + + const int ne10_dst = dst->ne[0]; + const int ne11_dst = dst->ne[1]; + const int ne12_dst = dst->ne[2]; + const int ne13_dst = dst->ne[3]; + + const float sf0 = (float)dst->ne[0] / src0->ne[0]; + const float sf1 = (float)dst->ne[1] / src0->ne[1]; + const float sf2 = (float)dst->ne[2] / src0->ne[2]; + const float sf3 = (float)dst->ne[3] / src0->ne[3]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03)); + + if (mode == GGML_SCALE_MODE_NEAREST) { + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3)); + } + + + size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst; + if (dst_total_elements == 0) { + return; + } + size_t global_work_size[] = { dst_total_elements, 1, 1 }; + size_t local_work_size_pref = 256; + size_t local_work_size[] = { MIN(local_work_size_pref, dst_total_elements), 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (dst_total_elements % local_work_size[0] != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + size_t profiling_gws[3] = {global_work_size[0], 1, 1}; + size_t profiling_lws[3] = {local_work_size_ptr ? local_work_size[0] : 0, 1, 1}; + populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 1, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) { + GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra0_cl->offset + src0->view_offs; + cl_ulong off_src1 = extra1_cl->offset + src1->view_offs; + cl_ulong off_dst = extrad_cl->offset + dst->view_offs; + + const int32_t dim = ((const int32_t *) dst->op_params)[0]; + GGML_ASSERT(dim >= 0 && dim <= 3); + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (dim == 3) { + + size_t nbytes_src0 = ggml_nbytes(src0); + size_t nbytes_src1 = ggml_nbytes(src1); + + CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device, + off_src0, off_dst, nbytes_src0, 0, NULL, NULL)); + CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device, + off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL)); + } else { + + cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous; + size_t global_work_size[3]; + + for (int i3 = 0; i3 < dst->ne[3]; ++i3) { + cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]); + cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]); + cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]); + + int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2]; + int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2]; + int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), ¤t_off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), ¤t_off_src1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), ¤t_off_dst)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim)); + + global_work_size[0] = d_ne0; + global_work_size[1] = d_ne1; + global_work_size[2] = d_ne2; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL)); + } + } + } else { + cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous; + + long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; + cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; + + cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; + + long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3]; + cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3]; + + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst)); + + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(long), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(long), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(long), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(long), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(long), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(long), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(long), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(long), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim)); + + size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1, + d_ne2 > 0 ? (size_t)d_ne2 : 1, + d_ne3 > 0 ? (size_t)d_ne3 : 1 }; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size_nc, NULL, 0, NULL, NULL)); + } +} + +static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + if (backend_ctx->kernel_timestep_embedding == nullptr) { + GGML_LOG_WARN("%s: timestep_embedding kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + const int logical_dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + const int dst_nb1_bytes = dst->nb[1]; + + cl_kernel kernel = backend_ctx->kernel_timestep_embedding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &dst_nb1_bytes)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &logical_dim)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &max_period)); + + size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1); + + size_t gws1 = (size_t)src0->ne[0]; + + size_t global_work_size[] = {gws0, gws1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, &evt)); // Pass 2 for 2D problem + + g_profiling_info.emplace_back(); + size_t profiling_gws[3] = {global_work_size[0], global_work_size[1], 1}; + size_t profiling_lws[3] = {0,0,0}; // Reflects NULL LWS + populateProfilingInfo(g_profiling_info.back(), evt, kernel, profiling_gws, profiling_lws, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, NULL, 0, NULL, NULL)); // Pass 2 for 2D problem +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -3030,7 +4906,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_context context = backend_ctx->context; - if (ne01 && ne1 && use_adreno_kernels(src0)) { + if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { // init CL objects // <--------------------------------------------> // @@ -3685,6 +5561,136 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } } +static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src2 = dst->src[2]; + GGML_ASSERT(src2); + GGML_ASSERT(src2->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + +#ifdef GGML_OPENCL_SOA_Q + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; +#endif + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb02 = src0->nb[2]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + + const int ne20 = src2->ne[0]; + const int ne21 = src2->ne[1]; + + const cl_ulong nb21 = src2->nb[1]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + const int r2 = ne12/ne02; + const int r3 = ne13/ne03; + const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows + + GGML_ASSERT(ne00 == ne10); + + int sgs = 32; // subgroup size + int nsg = 1; // number of subgroups + int nrows = 1; // number of row in src1 + int ndst = 4; // number of values produced by each subgroup + + cl_kernel kernel; + + // subgroup mat vec + switch (src0->type) { + case GGML_TYPE_Q4_0: { + kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat; + + if (backend_ctx->gpu_family == INTEL) { + sgs = 16; + nsg = 1; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + nsg = 1; + ndst = 8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &r3)); + + break; + } + default: + GGML_ASSERT(false && "not implemented");; + } + + int _ne1 = 1; + int ne123 = dst_rows; + + size_t global_work_size[] = {(size_t)(ne01+ndst*nsg-1)/(ndst*nsg)*sgs, (size_t)(_ne1+nrows-1)/nrows*nsg, (size_t)ne123}; + size_t local_work_size[] = {(size_t)sgs, (size_t)nsg, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -3719,14 +5725,19 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } @@ -3907,14 +5918,19 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02}; size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (ne00 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + #ifdef GGML_OPENCL_PROFILING cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, &evt)); g_profiling_info.emplace_back(); - populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size_ptr, dst); #else - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size_ptr, 0, NULL, NULL)); #endif } } @@ -4304,6 +6320,124 @@ static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, con #endif } +static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int nrows = ggml_nrows(src0); + + int ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + int order = (enum ggml_sort_order) dst->op_params[0]; + + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00_padded)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &order)); + CL_CHECK(clSetKernelArg(kernel, 7, ne00_padded*sizeof(int), NULL)); + + size_t global_work_size[] = {(size_t)ne00_padded, (size_t)nrows, (size_t)1}; + size_t local_work_size[] = {(size_t)ne00_padded, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + +static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_sum_rows_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)64, 1, 1}; + +#ifdef GGML_OPENCL_PROFILING + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + + g_profiling_info.emplace_back(); + populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst); +#else + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); +#endif +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -4344,8 +6478,6 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor if (!any_on_device) { return false; } - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); func = ggml_cl_add; break; case GGML_OP_MUL: @@ -4354,6 +6486,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_mul; break; + case GGML_OP_DIV: + if (!any_on_device) { + return false; + } + func = ggml_cl_div; + break; + case GGML_OP_SUB: + if (!any_on_device) { + return false; + } + func = ggml_cl_sub; + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: @@ -4380,6 +6524,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_relu; break; + case GGML_UNARY_OP_SIGMOID: + if (!any_on_device) { + return false; + } + func = ggml_cl_sigmoid; + break; + case GGML_UNARY_OP_TANH: + if (!any_on_device) { + return false; + } + func = ggml_cl_tanh; + break; default: return false; } break; @@ -4401,12 +6557,54 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rms_norm; break; + case GGML_OP_GROUP_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_group_norm; + break; + case GGML_OP_REPEAT: + if (!any_on_device) { + return false; + } + func = ggml_cl_repeat; + break; + case GGML_OP_PAD: + if (!any_on_device) { + return false; + } + ggml_cl_pad(backend, tensor->src[0], tensor); + return true; + case GGML_OP_UPSCALE: + if (!any_on_device) { + return false; + } + ggml_cl_upscale(backend, tensor->src[0], tensor); + return true; + case GGML_OP_CONCAT: + if (!any_on_device) { + return false; + } + func = ggml_cl_concat; + break; + case GGML_OP_TIMESTEP_EMBEDDING: + if (!any_on_device) { + return false; + } + ggml_cl_timestep_embedding(backend, tensor->src[0], tensor); + return true; case GGML_OP_MUL_MAT: if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) { return false; } func = ggml_cl_mul_mat; break; + case GGML_OP_MUL_MAT_ID: + if (!any_on_device) { + return false; + } + func = ggml_cl_mul_mat_id; + break; case GGML_OP_SCALE: if (!any_on_device) { return false; @@ -4446,6 +6644,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_im2col; break; + case GGML_OP_ARGSORT: + if (!any_on_device) { + return false; + } + func = ggml_cl_argsort; + break; + case GGML_OP_SUM_ROWS: + if (!any_on_device) { + return false; + } + func = ggml_cl_sum_rows; + break; default: return false; } diff --git a/ggml/src/ggml-opencl/kernels/add.cl b/ggml/src/ggml-opencl/kernels/add.cl new file mode 100644 index 000000000..f73f3c013 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/add.cl @@ -0,0 +1,83 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// add +//------------------------------------------------------------------------------ + +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient +kernel void kernel_add( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] + src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/argsort.cl b/ggml/src/ggml-opencl/kernels/argsort.cl new file mode 100644 index 000000000..af4adc7b8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/argsort.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define SWAP(x, y, T) { T tmp = (x); (x) = (y); (y) = tmp; } + +enum ggml_sort_order { + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, +}; + +kernel void kernel_argsort_f32_i32( + global float * src0, + ulong offset0, + global int * dst, + ulong offsetd, + const int ne00, + const int ne00_pad, + const int order, + local int * dst_row +) { + // bitonic sort + int col = get_local_id(0); + int row = get_group_id(1); + + if (col >= ne00_pad) { + return; + } + + src0 = (global char *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + global float * x_row = src0 + row * ne00; + + // initialize indices + dst_row[col] = col; + + barrier(CLK_LOCAL_MEM_FENCE); + + for (int k = 2; k <= ne00_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ne00 || + (dst_row[ixj] < ne00 && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj], int); + } + } else { + if (dst_row[ixj] >= ne00 || + (dst_row[col] < ne00 && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + SWAP(dst_row[col], dst_row[ixj], int); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + } + + // copy the result to dst without the padding + if (col < ne00) { + dst[row * ne00 + col] = dst_row[col]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/clamp.cl b/ggml/src/ggml-opencl/kernels/clamp.cl new file mode 100644 index 000000000..ae6032444 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/clamp.cl @@ -0,0 +1,20 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// clamp +//------------------------------------------------------------------------------ +kernel void kernel_clamp( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float min, + float max +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = src0[get_global_id(0)] < min ? + min : + (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl new file mode 100644 index 000000000..132758469 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -0,0 +1,109 @@ +kernel void kernel_concat_f32_contiguous( + global const char * p_src0, ulong off_src0, + global const char * p_src1, ulong off_src1, + global char * p_dst, ulong off_dst, + int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice + int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes) + int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice + int dim +) { + global const float * src0 = (global const float*)((global char*)p_src0 + off_src0); + global const float * src1 = (global const float*)((global char*)p_src1 + off_src1); + global float * dst = (global float*)((global char*)p_dst + off_dst); + + int i0 = get_global_id(0); // Index along dst's 0th dimension + int i1 = get_global_id(1); // Index along dst's 1st dimension + int i2 = get_global_id(2); // Index along dst's 2nd dimension + + if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) { + return; + } + + ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0; + ulong src_idx; + + if (dim == 0) { + if (i0 < d_ne00) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00); + dst[dst_idx] = src1[src_idx]; + } + } else if (dim == 1) { + if (i1 < d_ne01) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0; + dst[dst_idx] = src1[src_idx]; + } + } else if (dim == 2) { + if (i2 < d_ne02) { // Data from src0 + src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; + dst[dst_idx] = src0[src_idx]; + } else { // Data from src1 + + src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0; + dst[dst_idx] = src1[src_idx]; + } + } +} + +kernel void kernel_concat_f32_non_contiguous( + global const char * p_src0, ulong off_src0, + global const char * p_src1, ulong off_src1, + global char * p_dst, ulong off_dst, + + long ne00, long ne01, long ne02, long ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + + ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1 + + long d_ne0, long d_ne1, long d_ne2, long d_ne3, + ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3, + int dim +) { + global const char * src0_base = p_src0 + off_src0; + global const char * src1_base = p_src1 + off_src1; + global char * dst_base = p_dst + off_dst; + + long current_i1 = get_global_id(0); // Index for dst_dim_1 + long current_i2 = get_global_id(1); // Index for dst_dim_2 + long current_i3 = get_global_id(2); // Index for dst_dim_3 + + if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) { + return; + } + + global const float * x_val_ptr; + global float * y_val_ptr; + + for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) { + bool use_src0; + long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3; + + if (dim == 0) { + use_src0 = (current_i0 < ne00); + if (!use_src0) { s_i0 = current_i0 - ne00; } + } else if (dim == 1) { + use_src0 = (current_i1 < ne01); + if (!use_src0) { s_i1 = current_i1 - ne01; } + } else if (dim == 2) { + use_src0 = (current_i2 < ne02); + if (!use_src0) { s_i2 = current_i2 - ne02; } + } else { // dim == 3 + use_src0 = (current_i3 < ne03); + if (!use_src0) { s_i3 = current_i3 - ne03; } + } + + if (use_src0) { + x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00); + } else { + x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10); + } + + y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0); + *y_val_ptr = *x_val_ptr; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl new file mode 100644 index 000000000..9369351a6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -0,0 +1,184 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// cpy +//------------------------------------------------------------------------------ + +kernel void kernel_cpy_f16_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + global half * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + + src0 = (global half*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + global float * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl similarity index 69% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl rename to ggml/src/ggml-opencl/kernels/cvt.cl index e2024332f..fe7975e3d 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -1,39 +1,20 @@ //------------------------------------------------------------------------------ -// This file is contains additional kernels for data conversion. +// This file is contains kernels for data conversion. // These kernels are used when loading the model, so its performance is less // important. //------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif #ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable #define INTEL_GPU 1 #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) #elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable #define ADRENO_GPU 1 #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." #endif #define QK4_0 32 @@ -66,13 +47,44 @@ struct block_q4_0 }; //------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat_noshuffle -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. It also uses non shuffled bit order for weights. -// -// The shuffled version is kept in the original file because moving it here -// seems to result in worse performance for adreno. +// kernel_convert_block_q4_0 +// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_0( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_0/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_0( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/2; ++i) { + b->qs[i] = q[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0_noshuffle +// Flatten q4_0 weights and unshuffle the bits //------------------------------------------------------------------------------ kernel void kernel_convert_block_q4_0_noshuffle( diff --git a/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl new file mode 100644 index 000000000..36eff0439 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl @@ -0,0 +1,58 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// diag_mask_inf kernels +//------------------------------------------------------------------------------ +kernel void kernel_diag_mask_inf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i02 = get_global_id(2); + int i01 = get_global_id(1); + int i00 = get_global_id(0); + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + int i = 2*get_global_id(0); + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int i4 = 4*i; + int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + int i01 = i4/(ne00); i4 -= i01*ne00; + int i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + (&dst[i+1])[k] = -INFINITY; + if (i00 + k > n_past + i01) { + (&dst[i])[k] = -INFINITY; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/div.cl b/ggml/src/ggml-opencl/kernels/div.cl new file mode 100644 index 000000000..d453ad99b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/div.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// div +//------------------------------------------------------------------------------ +kernel void kernel_div( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) / *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_div_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] / src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/gelu.cl b/ggml/src/ggml-opencl/kernels/gelu.cl new file mode 100644 index 000000000..71c310cc9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gelu.cl @@ -0,0 +1,62 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// gelu +//------------------------------------------------------------------------------ +#define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f + +kernel void kernel_gelu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl new file mode 100644 index 000000000..b3fea2923 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +#define QK4_0 32 + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + + +//------------------------------------------------------------------------------ +// dequantize_q4_0_f32, dequantize_q4_0_f16 +//------------------------------------------------------------------------------ +void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + float d1 = il ? (xb->d / 16.h) : xb->d; + float d2 = d1 / 256.f; + float md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + + +//------------------------------------------------------------------------------ +// get_rows +//------------------------------------------------------------------------------ +kernel void kernel_get_rows_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_q4_0( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int NL = 2; + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { + float16 temp; + dequantize_q4_0_f32( + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl deleted file mode 100644 index b88792887..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ /dev/null @@ -1,3231 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q4_1 -//------------------------------------------------------------------------------ -struct block_q4_1 -{ - half d; - half m; - uint8_t qs[QK4_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_0 -//------------------------------------------------------------------------------ -struct block_q5_0 -{ - half d; - uint32_t qh; - uint8_t qs[QK5_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_1 -//------------------------------------------------------------------------------ -struct block_q5_1 -{ - half d; - half m; - uint32_t qh; - uint8_t qs[QK5_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q8_0 -//------------------------------------------------------------------------------ -struct block_q8_0 -{ - half d; - int8_t qs[QK8_0]; -}; - -//------------------------------------------------------------------------------ -// block_q2_K -//------------------------------------------------------------------------------ -struct block_q2_K -{ - uint8_t scales[16]; - uint8_t qs[64]; - half d; - half dmin; -}; - -//------------------------------------------------------------------------------ -// block_q3_K -//------------------------------------------------------------------------------ -struct block_q3_K -{ - uint8_t hmask[32]; - uint8_t qs[64]; - uint8_t scales[12]; - half d; -}; - -//------------------------------------------------------------------------------ -// block_q4_K -//------------------------------------------------------------------------------ -struct block_q4_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q5_K -//------------------------------------------------------------------------------ -struct block_q5_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qh[32]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -struct block_q6_K -{ - uint8_t ql[128]; - uint8_t qh[64]; - int8_t scales[16]; - half d; -}; - -//------------------------------------------------------------------------------ -// dequantize_q4_0_f32, dequantize_q4_0_f16 -//------------------------------------------------------------------------------ -void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - float d1 = il ? (xb->d / 16.h) : xb->d; - float d2 = d1 / 256.f; - float md = -8.h * xb->d; - ushort mask0 = il ? 0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - half d1 = il ? (xb->d / 16.h) : xb->d; - half d2 = d1 / 256.h; - half md = -8.h * xb->d; - ushort mask0 = il ? 0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -//------------------------------------------------------------------------------ -// add -//------------------------------------------------------------------------------ - -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 -// cons: not very efficient -kernel void kernel_add( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. - uint gid = get_global_id(0); - uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne - dst[gid] = src0[gid] + src1[idx1]; -} - -//------------------------------------------------------------------------------ -// mul -//------------------------------------------------------------------------------ -kernel void kernel_mul( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_mul_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. - uint gid = get_global_id(0); - uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne - dst[gid] = src0[gid] * src1[idx1]; -} - -//------------------------------------------------------------------------------ -// scale -//------------------------------------------------------------------------------ -kernel void kernel_scale( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd, - float scale -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - dst[get_global_id(0)] = src0[get_global_id(0)] * scale; -} - -//------------------------------------------------------------------------------ -// gelu -//------------------------------------------------------------------------------ -#define GELU_COEF_A 0.044715f -#define GELU_QUICK_COEF -1.702f -#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f - -kernel void kernel_gelu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - - dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - - dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -//------------------------------------------------------------------------------ -// silu -//------------------------------------------------------------------------------ -kernel void kernel_silu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x / (1.0f + exp(-x)); -} - -//------------------------------------------------------------------------------ -// relu -//------------------------------------------------------------------------------ -kernel void kernel_relu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); -} - -//------------------------------------------------------------------------------ -// clamp -//------------------------------------------------------------------------------ -kernel void kernel_clamp( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - float min, - float max -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - dst[get_global_id(0)] = src0[get_global_id(0)] < min ? - min : - (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]); -} - -//------------------------------------------------------------------------------ -// norm -//------------------------------------------------------------------------------ -kernel void kernel_norm( - global void * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb01, - ulong nb02, - ulong nb03, - float eps, - local float * sum -) { - src0 = (global void*)((global char*)src0 + offset0); - dst = (global void*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); - - // MEAN - // parallel sum - sum[get_local_id(0)] = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - sum[get_local_id(0)] += x[i00]; - } - // reduce - barrier(CLK_LOCAL_MEM_FENCE); - for (uint i = get_local_size(0)/2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - float mean = sum[0] / ne00; - - // recenter and VARIANCE - barrier(CLK_LOCAL_MEM_FENCE); - global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - sum[get_local_id(0)] = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = x[i00] - mean; - sum[get_local_id(0)] += y[i00] * y[i00]; - } - - // reduce - barrier(CLK_LOCAL_MEM_FENCE); - for (uint i = get_local_size(0)/2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - float variance = sum[0] / ne00; - - float scale = 1.0f/sqrt(variance + eps); - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = y[i00] * scale; - } -} - -//------------------------------------------------------------------------------ -// rms_norm -//------------------------------------------------------------------------------ -// This kernel depends on subgroup size. -kernel void kernel_rms_norm( - global void * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb01, - ulong nb02, - ulong nb03, - float eps, - local float * sum // Note, the size depends on number of subgroups -) { - src0 = (global void*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); - global float * x_scalar = (global float *) x; - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; - all_sum = sub_group_reduce_add(all_sum); - if (get_sub_group_local_id() == 0) { - sum[get_sub_group_id()] = all_sum; - } - - barrier(CLK_LOCAL_MEM_FENCE); - // broadcast - for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - } - if (get_local_id(0) == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) { - sum[0] += x_scalar[i]; - } - sum[0] /= ne00; - } - - barrier(CLK_LOCAL_MEM_FENCE); - - const float mean = sum[0]; - const float scale = 1.0f/sqrt(mean + eps); - - global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float * y_scalar = (global float *) y; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - y[i00] = x[i00] * scale; - } - if (get_local_id(0) == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { - y_scalar[i00] = x_scalar[i00] * scale; - } - } -} - -//------------------------------------------------------------------------------ -// diag_mask_inf kernels -//------------------------------------------------------------------------------ -kernel void kernel_diag_mask_inf( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i02 = get_global_id(2); - int i01 = get_global_id(1); - int i00 = get_global_id(0); - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - int i = 2*get_global_id(0); - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int i4 = 4*i; - int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - int i01 = i4/(ne00); i4 -= i01*ne00; - int i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - (&dst[i+1])[k] = -INFINITY; - if (i00 + k > n_past + i01) { - (&dst[i])[k] = -INFINITY; - } - } -} - -//------------------------------------------------------------------------------ -// softmax -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -//------------------------------------------------------------------------------ -// kernel_rope -//------------------------------------------------------------------------------ -float rope_yarn_ramp(float low, float high, int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -float2 rope_yarn( - float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - return (float2)(cos(theta) * mscale, sin(theta) * mscale); -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -float2 rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow -) { - // start and end correction dims - return (float2)( - max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), - min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) - ); -} - -kernel void kernel_rope_norm_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_norm_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_vision_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -kernel void kernel_rope_vision_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -//------------------------------------------------------------------------------ -// cpy -//------------------------------------------------------------------------------ - -kernel void kernel_cpy_f16_f16( - global half * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global half*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - global half * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - - src0 = (global half*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - global float * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -//------------------------------------------------------------------------------ -// get_rows -//------------------------------------------------------------------------------ -kernel void kernel_get_rows_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_q4_0( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int NL = 2; - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { - float16 temp; - dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f32_f32 -//------------------------------------------------------------------------------ -#define N_F32_F32 4 - -kernel void kernel_mul_mat_f32_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F32_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global float * x = (global float *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global float4 * x4 = (global float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f16 -//------------------------------------------------------------------------------ -#define N_F16_F16 4 - -kernel void kernel_mul_mat_f16_f16( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3) -{ - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F16; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - global half4 * y4 = (global half4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (half) x4[i].s0 * y4[i].s0; - sumf += (half) x4[i].s1 * y4[i].s1; - sumf += (half) x4[i].s2 * y4[i].s2; - sumf += (half) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (half) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_1row -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_1row( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * x = (global half *) (src0 + offset_src0); - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - if (ne00 < 128) { - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - global half4 * x4 = (global half4 *) x; - global float4 * y4 = (global float4 *) y; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32 -//------------------------------------------------------------------------------ -#define N_F16_F32 4 - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += convert_float(x[i]) * y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_l4 -//------------------------------------------------------------------------------ -// Assumes row size (ne00) is a multiple of 4 -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_l4( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int nrows = ne11; - int r0 = get_group_id(0); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half4 * x4 = (global half4 *) (src0 + offset_src0); - - for (int r1 = 0; r1 < nrows; ++r1) { - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float4 * y4 = (global float4 *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32 -//------------------------------------------------------------------------------ -// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_4_0_dot_y( - global struct block_q4_0 * qb_curr, - float sumy, - private float * yl, - int il -) { - float d = qb_curr->d; - float2 acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc.s0 + acc.s1); -} - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[N_DST]={0.f}; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); - } - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - // hence then entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float). Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float tot[N_DST] = { - sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), - sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; - for (int row = 0; row < N_DST; ++row) { - if (get_sub_group_local_id() == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant unrolls the loops and uses vector types instead of pointers. -// It improves performance on Adreno but not so much on Intel. -// -inline float block_q_4_0_dot_y_v( - global struct block_q4_0 * qb_curr, - float sumy, - float16 yl, - int il -) { - float d = qb_curr->d; - float acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_v( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; // src1 vector cache - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - // hence then entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float). Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_v( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_convert_block_q4_0 -// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). -// This kernel does not deshuffle the bits. -//------------------------------------------------------------------------------ -kernel void kernel_convert_block_q4_0( - global struct block_q4_0 * src0, - global uchar * dst_q, - global half * dst_d -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); - global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) dst_d + get_global_id(0); - - *d = b->d; - - for (int i = 0; i < QK4_0/2; ++i) { - q[i] = b->qs[i]; - } -} - -kernel void kernel_restore_block_q4_0( - global uchar * src_q, - global half * src_d, - global struct block_q4_0 * dst -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); - global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) src_d + get_global_id(0); - - b->d = *d; - for (int i = 0; i < QK4_0/2; ++i) { - b->qs[i] = q[i]; - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. -//------------------------------------------------------------------------------ - -// This function requires the original shuffled weights. -// As a reminder, the original weights are shuffled so that (q[0], q[16]) are -// packed together in a byte, so are (q[1], q[17]) and so on. -inline float block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant outputs 8 values. -// -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = 0.f; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl deleted file mode 100644 index 9b41dfb25..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl +++ /dev/null @@ -1,146 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -kernel void kernel_im2col_f32( - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - // threadIdx.x + blockIdx.x * blockDim.x - long i = get_global_id(0); - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} - -kernel void kernel_im2col_f16( - global float * src1, - ulong offset1, - global half * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - long i = get_global_id(0); - - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global half*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl deleted file mode 100644 index e19e9a2f4..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl +++ /dev/null @@ -1,1225 +0,0 @@ -//------------------------------------------------------------------------------ -// This file is contains additional mulmat kernels -// (and potentially other kernels). -//------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; - -//------------------------------------------------------------------------------ -// These are the variant for matmatmul, based on the matvecmul kernel with -// flattened block_q4_0. -//------------------------------------------------------------------------------ - -// Common dot prod. -inline float mm_block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 8x output. -// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix). -// -inline void mul_mat_q_n_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 16 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 16x output. -// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). -// -inline void mul_mat_q_n_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); - sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); - sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); - sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); - - sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); - sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); - sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); - sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float16 tot = (float16)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), - - sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), - sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), - sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), - sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - - if (first_row + 8 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; - } - if (first_row + 9 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; - } - if (first_row + 10 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; - } - if (first_row + 11 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; - } - - if (first_row + 12 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; - } - if (first_row + 13 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; - } - if (first_row + 14 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; - } - if (first_row + 15 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_mul_mat_q4_0_f32_flat_v0 -//------------------------------------------------------------------------------ -inline float block_q_4_0_dot_y_flat_v2( - half x, - half d, - float sumy, - float4 yl -) { - uchar2 q = as_uchar2(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - - acc += (q.s0 & 0xF0) * yl.s2; - acc += (q.s1 & 0xF0) * yl.s3; - - return d * (sumy * -8.f + acc);; -} - -inline float block_q_4_0_dot_y_flat_v4( - float x, - half d, - float sumy, - float8 yl -) { - uchar4 q = as_uchar4(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - - acc += (q.s0 & 0xF0) * yl.s4; - acc += (q.s1 & 0xF0) * yl.s5; - acc += (q.s2 & 0xF0) * yl.s6; - acc += (q.s3 & 0xF0) * yl.s7; - - return d * (sumy * -8.f + acc);; -} - -inline float block_q_4_0_dot_y_flat_v8( - float2 x, - half d, - float sumy, - float16 yl -) { - uchar8 q = as_uchar8(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - acc += (q.s4 & 0x0F) * yl.s4; - acc += (q.s5 & 0x0F) * yl.s5; - acc += (q.s6 & 0x0F) * yl.s6; - acc += (q.s7 & 0x0F) * yl.s7; - - acc += (q.s0 & 0xF0) * yl.s8; - acc += (q.s1 & 0xF0) * yl.s9; - acc += (q.s2 & 0xF0) * yl.sa; - acc += (q.s3 & 0xF0) * yl.sb; - acc += (q.s4 & 0xF0) * yl.sc; - acc += (q.s5 & 0xF0) * yl.sd; - acc += (q.s6 & 0xF0) * yl.se; - acc += (q.s7 & 0xF0) * yl.sf; - - return d * (sumy * -8.f + acc);; -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -# define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_v0( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f; - - q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2); - q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2); - q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2); -#endif -#if N_DST == 8 || N_DST == 16 - q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2)); - q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2)); - q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2)); - q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2)); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /= 16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl); - sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// Using image1d_buffer_t - -#if defined(cl_qcom_subgroup_shuffle) -#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable -float qcom_sub_group_reduce_add(float sum) { - sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - return sum; -} -#define sub_group_reduce_add qcom_sub_group_reduce_add -#else -#define sub_group_reduce_add sub_group_reduce_add -#endif - -#undef THREADS_PER_BLK -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define EXTRACT_BLK_DATA(tmp, part) *((float2*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define EXTRACT_BLK_DATA(tmp, part) *((float*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define EXTRACT_BLK_DATA(tmp, part) *((half*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -# define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_img_v0( - read_only image1d_buffer_t src0_q, - read_only image1d_buffer_t src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f;; - - float4 tmp; - tmp = read_imagef(src0_q, offset0_q + ib + 0*nb); - q_blk_0 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 1*nb); - q_blk_1 = EXTRACT_BLK_DATA(tmp, il); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 2*nb); - q_blk_2 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 3*nb); - q_blk_3 = EXTRACT_BLK_DATA(tmp, il); -#endif -#if N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 4*nb); - q_blk_4 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 5*nb); - q_blk_5 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 6*nb); - q_blk_6 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 7*nb); - q_blk_7 = EXTRACT_BLK_DATA(tmp, il); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /= 16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl); - sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// kernel_mul_mv_q6_K_f32 -//------------------------------------------------------------------------------ - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 1 // number of rows each SIMD group works on -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // SIMD group size -#elif defined (ADRENO_GPU) -#define N_DST 1 -#define N_SIMDGROUP 2 -#define N_SIMDWIDTH 64 -#endif - -#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mv_q6_K_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - uchar kmask1 = 0x03; - uchar kmask2 = 0x0C; - uchar kmask3 = 0x30; - uchar kmask4 = 0xC0; - - int nb = ne00/QK_K; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int row = N_SIMDGROUP * r0 + get_sub_group_id(); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; - global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - - // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a - // block. Values in a subblock shares a scale that is quantized with 8 bits; - // the entire block shares a single floating point scale. - // For work distribution, each thread processes a subblock (16 weights), hence - // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 - // (super) blocks -- this is the block stride. - // The 16 threads that process a (super) block are split into 2 portions, each has - // 8 threads; each portion works on 8 subblocks. - // For subgroup of 16 threads, the entire subgroup works on a single (super) block - // before moving to the next (super) block. Thread0 - thread7 work on the - // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. - // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on - // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but - // works on a total of 16 weight values. - int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 - int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 - int ip = tid/8; // first or second half of (super) block (0 or 1) - int il = tid%8; // each half has 8 parts, one per scale - int n = 4; // 4 scales at a time (and 4 sums) - int l0 = n*il; // offset into half-block, 0..28 - int is = 8*ip + l0/16; // 0, 1, 8, 9 - - int y_offset = 128*ip + l0; - int q_offset_l = 64*ip + l0; - int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += BLOCK_STRIDE) { - - global uint8_t * q1 = x[i].ql + q_offset_l; - global uint8_t * q2 = q1 + QK_K/8; - global uint8_t * qh = x[i].qh + q_offset_h; - global int8_t * sc = x[i].scales + is; - - global float * y = yy + i * QK_K + y_offset; - - float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - - sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); - sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); - sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); - sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); - sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); - sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); - sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); - sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); - sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); - sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); - sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); - sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); - sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); - - sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); - } - - float tot = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl deleted file mode 100644 index cd4e0afba..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl +++ /dev/null @@ -1,26 +0,0 @@ -// 16-bit transpose, loading/storing a 4x4 tile of elements - -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_16( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - half4 temp0 = read_imageh(input, (j_2+0)*cols+i); - half4 temp1 = read_imageh(input, (j_2+1)*cols+i); - half4 temp2 = read_imageh(input, (j_2+2)*cols+i); - half4 temp3 = read_imageh(input, (j_2+3)*cols+i); - - write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl deleted file mode 100644 index 914ec0193..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl +++ /dev/null @@ -1,25 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements - -kernel void kernel_transpose_32( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - float4 temp0 = read_imagef(input, (j_2+0)*cols+i); - float4 temp1 = read_imagef(input, (j_2+1)*cols+i); - float4 temp2 = read_imagef(input, (j_2+2)*cols+i); - float4 temp3 = read_imagef(input, (j_2+3)*cols+i); - - write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); - -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl deleted file mode 100644 index d3bd1fabb..000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl +++ /dev/null @@ -1,35 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements -// Only used for activations -// converts to FP16 -// also adds zero padding for non multiple of 8 prompt lengths -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - half4 temp0 = {0,0,0,0}; // initialize outputs to 0 - half4 temp1 = {0,0,0,0}; - half4 temp2 = {0,0,0,0}; - half4 temp3 = {0,0,0,0}; - - if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 - temp0 = read_imageh(input, (j_2+0)*cols+i); - } - if((j_2+1)*cols+i*4+3 < rows*cols*16){ - temp1 = read_imageh(input, (j_2+1)*cols+i); - } - if((j_2+2)*cols+i*4+3 < rows*cols*16){ - temp2 = read_imageh(input, (j_2+2)*cols+i); - } - if((j_2+3)*cols+i*4+3 < rows*cols*16){ - temp3 = read_imageh(input, (j_2+3)*cols+i); - } - - write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding - write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/group_norm.cl b/ggml/src/ggml-opencl/kernels/group_norm.cl new file mode 100644 index 000000000..57c9df4d3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/group_norm.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Workgroup must be a subgroup +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_group_norm( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne, + int group_size, + float eps +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int start = get_group_id(0) * group_size; + int end = start + group_size; + + start += get_local_id(0); + + if (end >= ne) { + end = ne; + } + + float tmp = 0.0f; + + for (int j = start; j < end; j += get_local_size(0)) { + tmp += src0[j]; + } + + tmp = sub_group_reduce_add(tmp); + + const float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += get_local_size(0)) { + float xi = src0[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = sub_group_reduce_add(tmp); + + const float variance = tmp / group_size; + const float scale = 1.0f/sqrt(variance + eps); + for (int j = start; j < end; j += get_local_size(0)) { + dst[j] *= scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl new file mode 100644 index 000000000..b84c89846 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f16( + global float * src1, + ulong offset1, + global half * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global half*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl new file mode 100644 index 000000000..4bf65e4ea --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f32( + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul.cl b/ggml/src/ggml-opencl/kernels/mul.cl new file mode 100644 index 000000000..2a2b4eb70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul.cl @@ -0,0 +1,79 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// mul +//------------------------------------------------------------------------------ +kernel void kernel_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] * src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl rename to ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl new file mode 100644 index 000000000..9393b5494 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F16 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3) +{ + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F16; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + global half4 * y4 = (global half4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (half) x4[i].s0 * y4[i].s0; + sumf += (half) x4[i].s1 * y4[i].s1; + sumf += (half) x4[i].s2 * y4[i].s2; + sumf += (half) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (half) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl new file mode 100644 index 000000000..e52d3c6d4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += convert_float(x[i]) * y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl new file mode 100644 index 000000000..28d30212c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl @@ -0,0 +1,94 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_1row( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * x = (global half *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + if (ne00 < 128) { + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + global half4 * x4 = (global half4 *) x; + global float4 * y4 = (global float4 *) y; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl new file mode 100644 index 000000000..cdf8197c4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Assumes row size (ne00) is a multiple of 4 +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_l4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nrows = ne11; + int r0 = get_group_id(0); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half4 * x4 = (global half4 *) (src0 + offset_src0); + + for (int r1 = 0; r1 < nrows; ++r1) { + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float4 * y4 = (global float4 *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl new file mode 100644 index 000000000..ec71b8756 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F32_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F32_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global float * x = (global float *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global float4 * x4 = (global float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl new file mode 100644 index 000000000..7ccf41efb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_id_q4_0_f32_8x_flat.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. +// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // subgroup size +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global char * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = 0; + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_id_q4_0_f32_8x_flat( + global char * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global char * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb02, + int ne10, + int ne11, + int ne12, + ulong nb11, + ulong nb12, + int ne20, + int ne21, + ulong nb21, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float *)((global char *)src1 + offset1); + src2 = (global char *)((global char *)src2 + offset2); + dst = (global float *)((global char *)dst + offsetd); + + const int iid1 = get_group_id(2)/ne20; + const int idx = get_group_id(2)%ne20; + + const int i02 = ((global int *)(src2 + iid1*nb21))[idx]; + + const int i11 = idx%ne11; + const int i12 = iid1; + + const int i1 = idx; + const int i2 = i12; + + global char * src0_q_cur = src0_q + (i02*nb02/nb00)*(QK4_0/2); + global half * src0_d_cur = src0_d + (i02*nb02/nb00); + global float * src1_cur = (global float *)((global char *) src1 + i11*nb11 + i12*nb12); + global float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + + mul_vec_q_n_f32_8x_flat(src0_q_cur, src0_d_cur, src1_cur, dst_cur, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl new file mode 100644 index 000000000..52141e0ed --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_4_0_dot_y( + global struct block_q4_0 * qb_curr, + float sumy, + private float * yl, + int il +) { + float d = qb_curr->d; + float2 acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc.s0 + acc.s1); +} + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); + } + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl new file mode 100644 index 000000000..3eebab8f0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl @@ -0,0 +1,307 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 16x output. +// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); + sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); + sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); + sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); + + sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); + sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); + sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); + sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float16 tot = (float16)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), + + sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), + sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), + sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), + sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + + if (first_row + 8 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; + } + if (first_row + 9 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; + } + if (first_row + 10 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; + } + if (first_row + 11 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; + } + + if (first_row + 12 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; + } + if (first_row + 13 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; + } + if (first_row + 14 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; + } + if (first_row + 15 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl new file mode 100644 index 000000000..38024d00a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl @@ -0,0 +1,265 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 8x output. +// Eeach simdgroup outputs 8 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl new file mode 100644 index 000000000..aed1ce7b2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl @@ -0,0 +1,272 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. +// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 32 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl new file mode 100644 index 000000000..929552179 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// +// This variant unrolls the loops and uses vector types instead of pointers. +// It improves performance on Adreno but not so much on Intel. +// +inline float block_q_4_0_dot_y_v( + global struct block_q4_0 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_v( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essenatially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; // src1 vector cache + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence then entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_v( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl new file mode 100644 index 000000000..8a17b9aae --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl @@ -0,0 +1,190 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32 +//------------------------------------------------------------------------------ + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 1 // number of rows each SIMD group works on +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 1 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + uchar kmask1 = 0x03; + uchar kmask2 = 0x0C; + uchar kmask3 = 0x30; + uchar kmask4 = 0xC0; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int row = N_SIMDGROUP * r0 + get_sub_group_id(); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a + // block. Values in a subblock shares a scale that is quantized with 8 bits; + // the entire block shares a single floating point scale. + // For work distribution, each thread processes a subblock (16 weights), hence + // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 + // (super) blocks -- this is the block stride. + // The 16 threads that process a (super) block are split into 2 portions, each has + // 8 threads; each portion works on 8 subblocks. + // For subgroup of 16 threads, the entire subgroup works on a single (super) block + // before moving to the next (super) block. Thread0 - thread7 work on the + // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. + // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on + // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but + // works on a total of 16 weight values. + int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += BLOCK_STRIDE) { + + global uint8_t * q1 = x[i].ql + q_offset_l; + global uint8_t * q2 = q1 + QK_K/8; + global uint8_t * qh = x[i].qh + q_offset_h; + global int8_t * sc = x[i].scales + is; + + global float * y = yy + i * QK_K + y_offset; + + float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + } + + float tot = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl new file mode 100644 index 000000000..43167ba4d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -0,0 +1,81 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// norm +//------------------------------------------------------------------------------ +kernel void kernel_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global void*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + + // MEAN + // parallel sum + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sum[get_local_id(0)] += x[i00]; + } + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float mean = sum[0] / ne00; + + // recenter and VARIANCE + barrier(CLK_LOCAL_MEM_FENCE); + global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] - mean; + sum[get_local_id(0)] += y[i00] * y[i00]; + } + + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float variance = sum[0] / ne00; + + float scale = 1.0f/sqrt(variance + eps); + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = y[i00] * scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl new file mode 100644 index 000000000..747fa7feb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -0,0 +1,30 @@ +kernel void kernel_pad( + global const void * src0_ptr, + ulong src0_offset, + global void * dst_ptr, + ulong dst_offset, + int s_ne0, int s_ne1, int s_ne2, + int d_ne0, int d_ne1, int d_ne2 +) { + global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); + global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + + int nidx = get_global_id(0); + int idx_d1 = get_group_id(1); + int idx_d2 = get_group_id(2); + + if (nidx >= d_ne0) { + return; + } + + int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + + bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + + if (in_src_bounds) { + int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; + dst[dst_el_offset] = src0[src_el_offset]; + } else { + dst[dst_el_offset] = 0.0f; + } +} diff --git a/ggml/src/ggml-opencl/kernels/relu.cl b/ggml/src/ggml-opencl/kernels/relu.cl new file mode 100644 index 000000000..60ff28a61 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/relu.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// relu +//------------------------------------------------------------------------------ +kernel void kernel_relu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl new file mode 100644 index 000000000..079498f5a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/repeat.cl @@ -0,0 +1,39 @@ +kernel void kernel_repeat( + global const char * src0_data_in, + global char * dst_data_in, + ulong src0_offset, + ulong dst_offset, + int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3, + ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3, + int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3, + ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3 +) { + global const char * src0_data = src0_data_in + src0_offset; + global char * dst_data = dst_data_in + dst_offset; + + const int d3 = get_global_id(2); + const int d2 = get_global_id(1); + const int d1 = get_global_id(0); + + if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) { + return; + } + + const int s3 = d3 % src0_ne3; + const int s2 = d2 % src0_ne2; + const int s1 = d1 % src0_ne1; + + const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1; + global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1; + + for (int d0 = 0; d0 < dst_ne0; ++d0) { + // Determine source index for dimension 0 based on tiling/broadcasting. + const int s0 = d0 % src0_ne0; + + const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0; + global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0; + for (int k = 0; k < src0_nb0; ++k) { + current_dst_el_ptr[k] = current_src_el_ptr[k]; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl new file mode 100644 index 000000000..9d21f3398 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -0,0 +1,96 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size. +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/rope.cl b/ggml/src/ggml-opencl/kernels/rope.cl new file mode 100644 index 000000000..0247730c0 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rope.cl @@ -0,0 +1,721 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// kernel_rope +//------------------------------------------------------------------------------ +float rope_yarn_ramp(float low, float high, int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +float2 rope_yarn( + float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + return (float2)(cos(theta) * mscale, sin(theta) * mscale); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +float2 rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow +) { + // start and end correction dims + return (float2)( + max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), + min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) + ); +} + +kernel void kernel_rope_norm_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_norm_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_vision_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + +kernel void kernel_rope_vision_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl new file mode 100644 index 000000000..8cfd518fa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// scale +//------------------------------------------------------------------------------ +kernel void kernel_scale( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + float scale +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale; +} diff --git a/ggml/src/ggml-opencl/kernels/sigmoid.cl b/ggml/src/ggml-opencl/kernels/sigmoid.cl new file mode 100644 index 000000000..e3f669dde --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sigmoid.cl @@ -0,0 +1,29 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// sigmoid +//------------------------------------------------------------------------------ + +kernel void kernel_sigmoid_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)])); +} + +kernel void kernel_sigmoid_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = 1.0f / (1.0f + exp(-src0[get_global_id(0)])); +} diff --git a/ggml/src/ggml-opencl/kernels/silu.cl b/ggml/src/ggml-opencl/kernels/silu.cl new file mode 100644 index 000000000..1d95e1b50 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/silu.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// silu +//------------------------------------------------------------------------------ +kernel void kernel_silu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl new file mode 100644 index 000000000..62c05369a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl new file mode 100644 index 000000000..d562774ea --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl new file mode 100644 index 000000000..d38d09967 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl new file mode 100644 index 000000000..001b587ab --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/sub.cl b/ggml/src/ggml-opencl/kernels/sub.cl new file mode 100644 index 000000000..041e88ad3 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sub.cl @@ -0,0 +1,72 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// div +//------------------------------------------------------------------------------ +kernel void kernel_sub( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) - *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_sub_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] - src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl new file mode 100644 index 000000000..c5f7c570f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl @@ -0,0 +1,39 @@ + +kernel void kernel_sum_rows_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int i3 = get_global_id(2); + int i2 = get_global_id(1); + int i1 = get_global_id(0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl new file mode 100644 index 000000000..d9da86b14 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tanh.cl @@ -0,0 +1,63 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +kernel void kernel_tanh_f32_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} + +kernel void kernel_tanh_f16_nd( + global void * p_src0_base, ulong off_src0_abs, + global void * p_dst_base, ulong off_dst_abs, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne10, int ne11, int ne12, int ne13, + ulong nb10, ulong nb11, ulong nb12, ulong nb13 +) { + int i0 = get_global_id(0); + int i1 = get_global_id(1); + int i2 = get_global_id(2); + + if (i0 < ne10 && i1 < ne11 && i2 < ne12) { + for (int i3 = 0; i3 < ne13; ++i3) { + ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; + global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + + ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; + global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + + *dst_val_ptr = tanh(*src_val_ptr); + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl new file mode 100644 index 000000000..a11490b30 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +// 16-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_16( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + half4 temp0 = read_imageh(input, (j_2+0)*cols+i); + half4 temp1 = read_imageh(input, (j_2+1)*cols+i); + half4 temp2 = read_imageh(input, (j_2+2)*cols+i); + half4 temp3 = read_imageh(input, (j_2+3)*cols+i); + + write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_32( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + float4 temp0 = read_imagef(input, (j_2+0)*cols+i); + float4 temp1 = read_imagef(input, (j_2+1)*cols+i); + float4 temp2 = read_imagef(input, (j_2+2)*cols+i); + float4 temp3 = read_imagef(input, (j_2+3)*cols+i); + + write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); + +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +// Only used for activations +// converts to FP16 +// also adds zero padding for non multiple of 8 prompt lengths +kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + half4 temp0 = {0,0,0,0}; // initialize outputs to 0 + half4 temp1 = {0,0,0,0}; + half4 temp2 = {0,0,0,0}; + half4 temp3 = {0,0,0,0}; + + if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 + temp0 = read_imageh(input, (j_2+0)*cols+i); + } + if((j_2+1)*cols+i*4+3 < rows*cols*16){ + temp1 = read_imageh(input, (j_2+1)*cols+i); + } + if((j_2+2)*cols+i*4+3 < rows*cols*16){ + temp2 = read_imageh(input, (j_2+2)*cols+i); + } + if((j_2+3)*cols+i*4+3 < rows*cols*16){ + temp3 = read_imageh(input, (j_2+3)*cols+i); + } + + write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding + write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} diff --git a/ggml/src/ggml-opencl/kernels/tsembd.cl b/ggml/src/ggml-opencl/kernels/tsembd.cl new file mode 100644 index 000000000..4b1107f70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tsembd.cl @@ -0,0 +1,48 @@ +kernel void kernel_timestep_embedding( + global const void * p_timesteps, + ulong off_timesteps, + global void * p_dst, + ulong off_dst, + int dst_nb1_bytes, + int logical_dim, + int max_period +) { + int local_i; + int local_j; + int local_half_dim; + float local_timestep_val; + float local_freq; + float local_arg; + global float * local_embed_data_ptr; + global const float * local_timesteps_input_ptr; + global float * local_dst_output_base_ptr; + + local_timesteps_input_ptr = (global const float *)((global char *)p_timesteps + off_timesteps); + local_dst_output_base_ptr = (global float *)((global char *)p_dst + off_dst); + + local_i = get_global_id(1); + local_j = get_global_id(0); + + local_half_dim = logical_dim / 2; + local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes); + + if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) { + local_embed_data_ptr[logical_dim] = 0.0f; + } + + if (local_j >= local_half_dim) { + return; + } + + local_timestep_val = local_timesteps_input_ptr[local_i]; + + if (local_half_dim == 0) { + local_freq = 1.0f; + } else { + local_freq = exp(-log((float)max_period) * (float)local_j / (float)local_half_dim); + } + + local_arg = local_timestep_val * local_freq; + local_embed_data_ptr[local_j] = cos(local_arg); + local_embed_data_ptr[local_j + local_half_dim] = sin(local_arg); +} diff --git a/ggml/src/ggml-opencl/kernels/upscale.cl b/ggml/src/ggml-opencl/kernels/upscale.cl new file mode 100644 index 000000000..219d31dbb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/upscale.cl @@ -0,0 +1,121 @@ +kernel void kernel_upscale( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10 * ne11 * ne12 * ne13; + + if (index >= dst_total_elements) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = index / (ne10 * ne11 * ne12); + + int i00 = (int)(i10 / sf0); + int i01 = (int)(i11 / sf1); + int i02 = (int)(i12 / sf2); + int i03 = (int)(i13 / sf3); + + ulong offset_src_element = (ulong)i03 * nb03 + (ulong)i02 * nb02 + (ulong)i01 * nb01 + (ulong)i00 * nb00; + global const float * src_element_ptr = (global const float *)(src_base + offset_src_element); + + dst_base[index] = *src_element_ptr; +} + +kernel void kernel_upscale_bilinear( + global const void * p_src0, + ulong off_src0, + global void * p_dst, + ulong off_dst, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne00_src, + int ne01_src, + int ne10_dst, + int ne11_dst, + int ne12_dst, + int ne13_dst, + float sf0, + float sf1, + float sf2, + float sf3 +) { + global const char * src_base = (global const char *)p_src0 + off_src0; + global float * dst_base = (global float *)((global char *)p_dst + off_dst); + + int index = get_global_id(0); + int dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + int i10_dst = index % ne10_dst; + int i11_dst = (index / ne10_dst) % ne11_dst; + int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + int i02_src = (int)(i12_dst / sf2); + int i03_src = (int)(i13_dst / sf3); + + const float pixel_offset = 0.5f; + + float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + long y0_src = (long)floor(y_src_f); + long y1_src = y0_src + 1; + + y0_src = max(0L, min(y0_src, (long)ne01_src - 1)); + y1_src = max(0L, min(y1_src, (long)ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = max(0.0f, min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + long x0_src = (long)floor(x_src_f); + long x1_src = x0_src + 1; + + x0_src = max(0L, min(x0_src, (long)ne00_src - 1)); + x1_src = max(0L, min(x1_src, (long)ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = max(0.0f, min(dx, 1.0f)); + + global const float * p_a = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_b = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y0_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_c = (global const float *)(src_base + (ulong)x0_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + global const float * p_d = (global const float *)(src_base + (ulong)x1_src * nb00 + (ulong)y1_src * nb01 + (ulong)i02_src * nb02 + (ulong)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst_base[index] = result; +} diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index 7c3e24103..a3c82d675 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -28,16 +28,19 @@ struct ggml_opt_dataset { }; struct ggml_opt_context { - ggml_backend_sched_t backend_sched = nullptr; - ggml_cgraph * allocated_graph = nullptr; - ggml_cgraph * allocated_graph_copy = nullptr; - struct ggml_context * ctx_static = nullptr; - struct ggml_context * ctx_static_cpu = nullptr; - struct ggml_context * ctx_compute = nullptr; - struct ggml_context * ctx_copy = nullptr; - ggml_backend_buffer_t buf_static = nullptr; - ggml_backend_buffer_t buf_static_cpu = nullptr; - std::mt19937 rng; + ggml_backend_sched_t backend_sched = nullptr; + ggml_cgraph * allocated_graph = nullptr; + ggml_cgraph * allocated_graph_copy = nullptr; + struct ggml_context * ctx_static = nullptr; + struct ggml_context * ctx_cpu = nullptr; + struct ggml_context * ctx_compute = nullptr; + struct ggml_context * ctx_copy = nullptr; + ggml_backend_buffer_t buf_static = nullptr; + ggml_backend_buffer_t buf_cpu = nullptr; + std::mt19937 rng; + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + enum ggml_opt_build_type build_type_alloc; struct ggml_tensor * inputs = nullptr; struct ggml_tensor * outputs = nullptr; @@ -50,6 +53,11 @@ struct ggml_opt_context { struct ggml_cgraph * gf = nullptr; struct ggml_cgraph * gb_grad = nullptr; struct ggml_cgraph * gb_opt = nullptr; + bool static_graphs = false; + bool eval_ready = false; + std::vector grad_accs; + std::vector grad_m; + std::vector grad_v; int64_t iter = 1; int32_t opt_period = 1; @@ -73,7 +81,13 @@ struct ggml_opt_result { // ====== Dataset ====== -ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { +ggml_opt_dataset_t ggml_opt_dataset_init( + enum ggml_type type_data, + enum ggml_type type_label, + int64_t ne_datapoint, + int64_t ne_label, + int64_t ndata, + int64_t ndata_shard) { GGML_ASSERT(ne_datapoint > 0); GGML_ASSERT(ne_label >= 0); GGML_ASSERT(ndata > 0); @@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, result->ctx = ggml_init(params); } - result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata); + result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata); result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; if (ne_label > 0) { - result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata); + result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata); result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; } else { result->labels = nullptr; @@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { delete dataset; } +int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) { + return dataset->ndata; +} + struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { return dataset->data; } @@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT( data_batch->type == dataset->data->type); + GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type); const size_t nb_data_batch = ggml_nbytes(data_batch); GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); @@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * } } +void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) { + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard *dataset->nbs_data; + char * ptr_data_batch = (char *) data_batch + ishard_batch*dataset->nbs_data; + memcpy(ptr_data_batch, ptr_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard *dataset->nbs_labels; + char * ptr_labels_batch = (char *) labels_batch + ishard_batch*dataset->nbs_labels; + memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels); + } +} + // ====== Model / Context ====== struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { @@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us return result; } +struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) { + return *((struct ggml_opt_optimizer_params *) userdata); +} + struct ggml_opt_params ggml_opt_default_params( ggml_backend_sched_t backend_sched, - struct ggml_context * ctx_compute, - struct ggml_tensor * inputs, - struct ggml_tensor * outputs, enum ggml_opt_loss_type loss_type) { return { /*backend_sched =*/ backend_sched, - /*ctx_compute =*/ ctx_compute, - /*inputs =*/ inputs, - /*logits =*/ outputs, + /*ctx_compute =*/ nullptr, + /*inputs =*/ nullptr, + /*logits =*/ nullptr, /*loss_type =*/ loss_type, /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, /*opt_period =*/ 1, @@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) { return dst; } -static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) { - GGML_ASSERT(graph); - if (opt_ctx->allocated_graph == graph) { - return; - } +static void ggml_opt_build(ggml_opt_context_t opt_ctx) { + GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"); + GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically"); - ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD && + !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1); - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE, - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - ggml_free(opt_ctx->ctx_copy); - opt_ctx->ctx_copy = ggml_init(params); - } - - opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); - - ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); - opt_ctx->allocated_graph = graph; -} - -ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { - ggml_opt_context_t result = new struct ggml_opt_context; - result->backend_sched = params.backend_sched; - result->ctx_compute = params.ctx_compute; - result->inputs = params.inputs; - result->outputs = params.outputs; - result->opt_period = params.opt_period; - result->get_opt_pars = params.get_opt_pars; - result->get_opt_pars_ud = params.get_opt_pars_ud; - - GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); - GGML_ASSERT(result->opt_period >= 1); - - const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD || - (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); - - ggml_set_input(result->inputs); - ggml_set_output(result->outputs); - - result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. - ggml_build_forward_expand(result->gf, result->outputs); + ggml_set_input(opt_ctx->inputs); + ggml_set_output(opt_ctx->outputs); int n_param = 0; - for (int i = 0; i < result->gf->n_nodes; ++i) { - if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { + for (int i = 0; i < opt_ctx->gf->n_nodes; ++i) { + const struct ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { n_param++; } + GGML_ASSERT(!(node->flags & GGML_TENSOR_FLAG_LOSS) && "support for extra loss terms not implemented"); } - { + if (!opt_ctx->ctx_static) { // The static context is used for: - // - gradients (1 tensor per param if using gradient accumulation) + // - gradients (1 per loss, 1 tensor per param if using gradient accumulation) // - optimizer momenta (2 tensors per param) - // - labels - // - loss + its gradient (up to 5 tensors) - // - pred - // - ncorrect (2 tensors). - const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); - const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead(); + // - labels (if using static graphs) + // - loss (if using static graphs, up to 5 tensors) + // - pred (if using static graphs) + // - ncorrect (if using static graphs, 2 tensors). + constexpr size_t n_loss = 1; + const size_t tensors_per_param = (accumulate ? 1 : 0) + + (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0; + const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static = ggml_init(params); + opt_ctx->ctx_static = ggml_init(params); } + GGML_ASSERT(opt_ctx->build_type <= opt_ctx->build_type_alloc); + { - // The static cpu context is used for: - // - optimizer parameters (1 for the entire context) + // The cpu context is allocated statically if using static graphs, dynamically otherwise. + // It is used for: + // - optimizer parameters (1 shared for all optimizer invocations) const size_t size_meta = 1 * ggml_tensor_overhead(); struct ggml_init_params params = { /*.mem_size =*/ size_meta, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; - result->ctx_static_cpu = ggml_init(params); + ggml_free(opt_ctx->ctx_cpu); + opt_ctx->ctx_cpu = ggml_init(params); + + ggml_backend_buffer_free(opt_ctx->buf_cpu); + opt_ctx->buf_cpu = nullptr; } + struct ggml_context * ctx_results = opt_ctx->static_graphs ? opt_ctx->ctx_static : opt_ctx->ctx_compute; - switch (params.loss_type) { + switch (opt_ctx->loss_type) { case GGML_OPT_LOSS_TYPE_MEAN: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean"); - result->loss_per_datapoint = true; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean"); + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_SUM: { - result->loss = ggml_sum(result->ctx_static, result->outputs); - ggml_set_name(result->loss, "loss_sum"); - result->loss_per_datapoint = false; + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->loss, "loss_sum"); + opt_ctx->loss_per_datapoint = false; break; } case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - ggml_set_name(result->labels, "labels"); - result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_cross_entropy"); - if (result->opt_period > 1) { - result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); - ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_cross_entropy_loss(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy"); + if (opt_ctx->opt_period > 1) { + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, 1.0f / opt_ctx->opt_period); + ggml_set_name(opt_ctx->loss, "loss_cross_entropy_scaled"); } - result->loss_per_datapoint = true; + opt_ctx->loss_per_datapoint = true; break; } case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { - result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); - ggml_set_input(result->labels); - ggml_set_name(result->labels, "labels"); - result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels); - ggml_set_name(result->loss, "loss_error"); - result->loss = ggml_sqr(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_squared_error"); - result->loss = ggml_sum(result->ctx_static, result->loss); - ggml_set_name(result->loss, "loss_sum_squared_error"); - const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); - result->loss = ggml_scale(result->ctx_static, result->loss, scale); - ggml_set_name(result->loss, "loss_mean_squared_error"); - result->loss_per_datapoint = true; + opt_ctx->labels = ggml_dup_tensor(ctx_results, opt_ctx->outputs); + ggml_set_input(opt_ctx->labels); + ggml_set_name(opt_ctx->labels, "labels"); + opt_ctx->loss = ggml_sub(ctx_results, opt_ctx->outputs, opt_ctx->labels); + ggml_set_name(opt_ctx->loss, "loss_error"); + opt_ctx->loss = ggml_sqr(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_squared_error"); + opt_ctx->loss = ggml_sum(ctx_results, opt_ctx->loss); + ggml_set_name(opt_ctx->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (opt_ctx->opt_period * ggml_nelements(opt_ctx->outputs)); + opt_ctx->loss = ggml_scale(ctx_results, opt_ctx->loss, scale); + ggml_set_name(opt_ctx->loss, "loss_mean_squared_error"); + opt_ctx->loss_per_datapoint = true; break; } } - ggml_set_output(result->loss); - ggml_set_loss(result->loss); - ggml_build_forward_expand(result->gf, result->loss); + ggml_set_output(opt_ctx->loss); + ggml_set_loss(opt_ctx->loss); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->loss); - result->pred = ggml_argmax(result->ctx_static, result->outputs); - ggml_set_name(result->pred, "pred"); - ggml_set_output(result->pred); - ggml_build_forward_expand(result->gf, result->pred); + if (opt_ctx->loss_type == GGML_OPT_LOSS_TYPE_CROSS_ENTROPY) { + opt_ctx->pred = ggml_argmax(ctx_results, opt_ctx->outputs); + ggml_set_name(opt_ctx->pred, "pred"); + ggml_set_output(opt_ctx->pred); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->pred); - if (result->labels) { - result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels)); - ggml_set_name(result->ncorrect, "ncorrect"); - ggml_set_output(result->ncorrect); - ggml_build_forward_expand(result->gf, result->ncorrect); - } else { - result->ncorrect = nullptr; + opt_ctx->ncorrect = ggml_count_equal(ctx_results, opt_ctx->pred, ggml_argmax(ctx_results, opt_ctx->labels)); + ggml_set_name(opt_ctx->ncorrect, "ncorrect"); + ggml_set_output(opt_ctx->ncorrect); + ggml_build_forward_expand(opt_ctx->gf, opt_ctx->ncorrect); } - if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - return result; + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_FORWARD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + return; } - // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. - result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf); - ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + if (opt_ctx->grad_accs.empty()) { + GGML_ASSERT(opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD); - if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) { - result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); - ggml_graph_reset(result->gb_grad); - return result; - } + const int n_nodes = opt_ctx->gf->n_nodes; + opt_ctx->grad_accs.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + opt_ctx->grad_accs[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_accs[i] = nullptr; + } + } - GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT); - - // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. - result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad); - - result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7); - ggml_set_input(result->adamw_params); - ggml_set_name(result->adamw_params, "adamw_params"); - - for (int i = result->gf->n_nodes-1; i >= 0; --i) { - struct ggml_tensor * node = result->gb_opt->nodes[i]; - struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); - ggml_build_forward_expand(result->gb_opt, opt_step); + if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) { + opt_ctx->grad_m.resize(n_nodes); + opt_ctx->grad_v.resize(n_nodes); + for (int i = 0; i < n_nodes; ++i) { + ggml_tensor * node = opt_ctx->gf->nodes[i]; + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + } else { + opt_ctx->grad_m[i] = nullptr; + opt_ctx->grad_v[i] = nullptr; + } + } } } - result->buf_static = ggml_backend_alloc_ctx_tensors( - result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. + opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true); + ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data()); - result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type()); + if (opt_ctx->buf_static) { + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) { + return; + } + } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_grad); + } - ggml_graph_reset(result->gb_opt); + GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true); + + opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7); + ggml_set_input(opt_ctx->adamw_params); + ggml_set_name(opt_ctx->adamw_params, "adamw_params"); + + for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node); + + if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) { + struct ggml_tensor * m = opt_ctx->grad_m[i]; + struct ggml_tensor * v = opt_ctx->grad_v[i]; + struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params); + + ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str()); + ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str()); + ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str()); + + ggml_build_forward_expand(opt_ctx->gb_opt, opt_step); + } + } + + if (!opt_ctx->buf_static) { + opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors( + opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0)); + ggml_graph_reset(opt_ctx->gb_opt); + } + + opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type()); +} + +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->ctx_compute = params.ctx_compute; + result->loss_type = params.loss_type; + result->build_type = params.build_type; + result->build_type_alloc = params.build_type; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->opt_period = params.opt_period; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->opt_period >= 1); + + result->static_graphs = result->ctx_compute; + + if (!result->static_graphs) { + GGML_ASSERT(!result->inputs); + GGML_ASSERT(!result->outputs); + return result; + } + + GGML_ASSERT(result->inputs); + GGML_ASSERT(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + ggml_build_forward_expand(result->gf, result->outputs); + + ggml_opt_build(result); return result; } @@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { return; } ggml_backend_buffer_free(opt_ctx->buf_static); - ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); - ggml_free(opt_ctx->ctx_static_cpu); + ggml_free(opt_ctx->ctx_cpu); delete opt_ctx; } @@ -479,6 +576,10 @@ void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { } } +bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx) { + return opt_ctx->static_graphs; +} + struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { return opt_ctx->inputs; } @@ -582,8 +683,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl // ====== Computation ====== -static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) { - if (graph != opt_ctx->gf) { +void ggml_opt_prepare_alloc( + ggml_opt_context_t opt_ctx, + struct ggml_context * ctx_compute, + struct ggml_cgraph * gf, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs) { + GGML_ASSERT(!opt_ctx->static_graphs); + opt_ctx->ctx_compute = ctx_compute; + opt_ctx->gf = gf; + opt_ctx->inputs = inputs; + opt_ctx->outputs = outputs; +} + +void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) { + GGML_ASSERT(!opt_ctx->eval_ready); + if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) { + ggml_graph_reset(opt_ctx->gb_grad); + } + if (backward) { + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD; + } else { + opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD; + } + + if (!opt_ctx->static_graphs) { + ggml_opt_build(opt_ctx); + } + + struct ggml_cgraph * graph = nullptr; + switch (opt_ctx->build_type) { + case GGML_OPT_BUILD_TYPE_FORWARD: { + graph = opt_ctx->gf; + } break; + case GGML_OPT_BUILD_TYPE_GRAD: { + graph = opt_ctx->gb_grad; + } break; + case GGML_OPT_BUILD_TYPE_OPT: { + graph = opt_ctx->gb_opt; + } break; + } + GGML_ASSERT(graph); + + if (opt_ctx->allocated_graph == graph) { + opt_ctx->eval_ready = true; + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + if (opt_ctx->static_graphs) { + ggml_init_params params = { + /*.mem_size =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + } else { + opt_ctx->allocated_graph_copy = graph; + } + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; + + opt_ctx->eval_ready = true; +} + +void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) { + GGML_ASSERT(opt_ctx->eval_ready); + if (opt_ctx->allocated_graph == opt_ctx->gb_opt) { struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); @@ -609,9 +781,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, adamw_par_data[6] = beta2h; } - ggml_opt_alloc_graph(opt_ctx, graph); ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + + if (!opt_ctx->static_graphs) { + opt_ctx->gf = nullptr; + opt_ctx->gb_grad = nullptr; + opt_ctx->gb_opt = nullptr; + opt_ctx->allocated_graph = nullptr; + opt_ctx->allocated_graph_copy = nullptr; + } + + opt_ctx->eval_ready = false; if (!result) { return; @@ -635,12 +817,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); result->loss.push_back(loss); - GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); - std::vector pred(ndata); - ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); - result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + if (opt_ctx->pred) { + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + } - if (!opt_ctx->labels || result->ncorrect < 0) { + if (!opt_ctx->ncorrect || result->ncorrect < 0) { result->ncorrect = -1; return; } @@ -652,26 +836,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, result->ncorrect += ncorrect; } -void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); -} - -void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { - if (opt_ctx->opt_period == 1) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - return; - } - - const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; - if (opt_i_next == 0) { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); - ggml_opt_reset(opt_ctx, /*optimizer =*/ false); - } else { - ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); - } - opt_ctx->opt_i = opt_i_next; -} - // ====== High-Level Functions ====== void ggml_opt_epoch( @@ -682,6 +846,7 @@ void ggml_opt_epoch( int64_t idata_split, ggml_opt_epoch_callback callback_train, ggml_opt_epoch_callback callback_eval) { + GGML_ASSERT(ggml_opt_static_graphs(opt_ctx) && "ggml_opt_epoch requires static graphs"); struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); struct ggml_tensor * data = ggml_opt_dataset_data(dataset); @@ -700,16 +865,18 @@ void ggml_opt_epoch( int64_t ibatch = 0; int64_t t_loop_start = ggml_time_us(); for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ true); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward_backward(opt_ctx, result_train); + ggml_opt_eval(opt_ctx, result_train); if (callback_train) { callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); } } t_loop_start = ggml_time_us(); for (; ibatch < nbatches; ++ibatch) { + ggml_opt_alloc(opt_ctx, /*backward =*/ false); ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); - ggml_opt_forward(opt_ctx, result_eval); + ggml_opt_eval(opt_ctx, result_eval); if (callback_eval) { callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); } @@ -726,13 +893,26 @@ void ggml_opt_epoch_callback_progress_bar( int64_t t_start_us) { fprintf(stderr, "%s[", train ? "train: " : "val: "); - constexpr int64_t bar_length = 25; + // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels. + constexpr int64_t bar_length = 8; + const int64_t ibatch8 = 8 * ibatch; for (int64_t j = 0; j < bar_length; ++j) { - const int64_t ibatch_j = ibatch_max * j/bar_length; - if (ibatch_j < ibatch) { - fprintf(stderr, "="); - } else if (ibatch_max * (j - 1)/bar_length < ibatch) { - fprintf(stderr, ">"); + if (ibatch_max * (8*j + 8) / bar_length < ibatch8) { + fprintf(stderr, "\u2588"); // full block + } else if (ibatch_max * (8*j + 7) / bar_length < ibatch8) { + fprintf(stderr, "\u2589"); // 7/8 filled + } else if (ibatch_max * (8*j + 6) / bar_length < ibatch8) { + fprintf(stderr, "\u258A"); // 6/8 filled + } else if (ibatch_max * (8*j + 5) / bar_length < ibatch8) { + fprintf(stderr, "\u258B"); // 5/8 filled + } else if (ibatch_max * (8*j + 4) / bar_length < ibatch8) { + fprintf(stderr, "\u258C"); // 4/8 filled + } else if (ibatch_max * (8*j + 3) / bar_length < ibatch8) { + fprintf(stderr, "\u258D"); // 3/8 filled + } else if (ibatch_max * (8*j + 2) / bar_length < ibatch8) { + fprintf(stderr, "\u258E"); // 2/8 filled + } else if (ibatch_max * (8*j + 1) / bar_length < ibatch8) { + fprintf(stderr, "\u258F"); // 1/8 filled } else { fprintf(stderr, " "); } @@ -764,8 +944,8 @@ void ggml_opt_epoch_callback_progress_bar( const int64_t t_eta_m = t_eta_s / 60; t_eta_s -= t_eta_m * 60; - fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " - "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r", idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); if (ibatch == ibatch_max) { @@ -806,7 +986,10 @@ void ggml_opt_fit( int64_t epoch = 1; - ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type); + params.ctx_compute = ctx_compute; + params.inputs = inputs; + params.outputs = outputs; params.opt_period = opt_period; params.get_opt_pars = get_opt_pars; params.get_opt_pars_ud = &epoch; diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index ac918a60d..e389a46db 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -19,12 +19,6 @@ #define GROUP_MAX_EPS_IQ1_M 1e-7f #define GROUP_MAX_EPS_IQ1_S 1e-12f -#if defined(_MSC_VER) -// disable "possible loss of data" to avoid warnings for hundreds of casts -// we should just be careful :) -#pragma warning(disable: 4244 4267) -#endif - #define UNUSED GGML_UNUSED // reference implementation for deterministic creation of model files @@ -2431,8 +2425,6 @@ void dequantize_row_iq1_m(const block_iq1_m * GGML_RESTRICT x, float * GGML_REST } } -static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - void dequantize_row_iq4_nl(const block_iq4_nl * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { assert(k % QK4_NL == 0); const int64_t nb = k / QK4_NL; diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 862b9b666..f468f796d 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1,6 +1,7 @@ #include "ggml-rpc.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-cpp.h" #include #include @@ -52,6 +53,9 @@ struct socket_t { } }; +// macro for nicer error messages on server crash +#define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") + // all RPC structures must be packed #pragma pack(push, 1) // ggml_tensor is serialized into rpc_tensor @@ -91,12 +95,19 @@ enum rpc_cmd { RPC_CMD_GET_DEVICE_MEMORY, RPC_CMD_INIT_TENSOR, RPC_CMD_GET_ALLOC_SIZE, + RPC_CMD_HELLO, RPC_CMD_COUNT, }; // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; +struct rpc_msg_hello_rsp { + uint8_t major; + uint8_t minor; + uint8_t patch; +}; + struct rpc_msg_get_alloc_size_req { rpc_tensor tensor; }; @@ -143,6 +154,12 @@ struct rpc_msg_buffer_clear_req { uint8_t value; }; +struct rpc_msg_set_tensor_hash_req { + rpc_tensor tensor; + uint64_t offset; + uint64_t hash; +}; + struct rpc_msg_set_tensor_hash_rsp { uint8_t result; }; @@ -370,8 +387,8 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int } // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | -// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { +// No response +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { return false; @@ -382,6 +399,15 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm if (!send_data(sock->fd, input, input_size)) { return false; } + return true; +} + +// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | +// RPC response: | response_size (8 bytes) | response_data (response_size bytes) | +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { + if (!send_rpc_cmd(sock, cmd, input, input_size)) { + return false; + } // TODO: currently the output_size is always known, do we need support for commands with variable output size? // even if we do, we can skip sending output_size from the server for commands with known output size uint64_t out_size; @@ -399,6 +425,20 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation +static bool check_server_version(const std::shared_ptr & sock) { + rpc_msg_hello_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); + RPC_STATUS_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { + fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + return false; + } + if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { + fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + } + return true; +} + static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -432,6 +472,9 @@ static std::shared_ptr get_socket(const std::string & endpoint) { if (sock == nullptr) { return nullptr; } + if (!check_server_version(sock)) { + return nullptr; + } GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); sockets[endpoint] = sock; return sock; @@ -441,7 +484,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; rpc_msg_free_buffer_req request = {ctx->remote_ptr}; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); delete ctx; } @@ -453,7 +496,7 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { rpc_msg_buffer_get_base_req request = {ctx->remote_ptr}; rpc_msg_buffer_get_base_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); ctx->base_ptr = reinterpret_cast(response.base_ptr); return ctx->base_ptr; } @@ -484,6 +527,11 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { result.view_src = reinterpret_cast(tensor->view_src); result.view_offs = tensor->view_offs; result.data = reinterpret_cast(tensor->data); + + // Avoid sending uninitialized data over the wire + memset(result.name, 0, sizeof(result.name)); + memset(result.padding, 0, sizeof(result.padding)); + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); return result; } @@ -500,7 +548,7 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_ request.tensor = serialize_tensor(tensor); bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); } return GGML_STATUS_SUCCESS; } @@ -509,16 +557,13 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; rpc_tensor rpc_tensor = serialize_tensor(tensor); if (size > HASH_THRESHOLD) { - // input serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) - size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + sizeof(uint64_t); - std::vector input(input_size, 0); - uint64_t hash = fnv_hash((const uint8_t*)data, size); - memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); - memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); - memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &hash, sizeof(hash)); + rpc_msg_set_tensor_hash_req request; + request.tensor = rpc_tensor; + request.offset = offset; + request.hash = fnv_hash((const uint8_t*)data, size); rpc_msg_set_tensor_hash_rsp response; - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, input.data(), input.size(), &response, sizeof(response)); - GGML_ASSERT(status); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response)); + RPC_STATUS_ASSERT(status); if (response.result) { // the server has the same data, no need to send it return; @@ -530,8 +575,8 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor)); memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); - bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0); - GGML_ASSERT(status); + bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size()); + RPC_STATUS_ASSERT(status); } static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { @@ -541,7 +586,7 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con request.offset = offset; request.size = size; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); } static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { @@ -559,7 +604,7 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con request.dst = serialize_tensor(dst); rpc_msg_copy_tensor_rsp response; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.result; } @@ -567,7 +612,7 @@ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value}; bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); } static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { @@ -593,7 +638,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back rpc_msg_alloc_buffer_rsp response; auto sock = get_socket(buft_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); if (response.remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, @@ -608,7 +653,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back static size_t get_alignment(const std::shared_ptr & sock) { rpc_msg_get_alignment_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.alignment; } @@ -620,7 +665,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ static size_t get_max_size(const std::shared_ptr & sock) { rpc_msg_get_max_size_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.max_size; } @@ -641,7 +686,7 @@ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_ty rpc_msg_get_alloc_size_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return response.alloc_size; } else { @@ -719,7 +764,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g rpc_msg_graph_compute_rsp response; auto sock = get_socket(rpc_ctx->endpoint); bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); return (enum ggml_status)response.result; } @@ -793,7 +838,7 @@ bool ggml_backend_is_rpc(ggml_backend_t backend) { static void get_device_memory(const std::shared_ptr & sock, size_t * free, size_t * total) { rpc_msg_get_device_memory_rsp response; bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response)); - GGML_ASSERT(status); + RPC_STATUS_ASSERT(status); *free = response.free_mem; *total = response.total_mem; } @@ -817,6 +862,7 @@ public: } ~rpc_server(); + void hello(rpc_msg_hello_rsp & response); void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); void get_alignment(rpc_msg_get_alignment_rsp & response); void get_max_size(rpc_msg_get_max_size_rsp & response); @@ -824,7 +870,7 @@ public: bool free_buffer(const rpc_msg_free_buffer_req & request); bool buffer_clear(const rpc_msg_buffer_clear_req & request); bool set_tensor(const std::vector & input); - bool set_tensor_hash(const std::vector & input, rpc_msg_set_tensor_hash_rsp & response); + bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); @@ -845,6 +891,13 @@ private: std::unordered_set buffers; }; +void rpc_server::hello(rpc_msg_hello_rsp & response) { + response.major = RPC_PROTO_MAJOR_VERSION; + response.minor = RPC_PROTO_MINOR_VERSION; + response.patch = RPC_PROTO_PATCH_VERSION; + GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch); +} + bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { ggml_backend_buffer_type_t buft; struct ggml_init_params params { @@ -853,12 +906,13 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n"); - ggml_free(ctx); return false; } @@ -871,7 +925,6 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_ response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor); - ggml_free(ctx); return true; } @@ -940,8 +993,21 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) { } ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) { + // Validate tensor type before using it + if (tensor->type >= GGML_TYPE_COUNT) { + GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type); + return nullptr; + } + ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + + // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type + if (result == nullptr) { + GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type); + return nullptr; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { result->nb[i] = tensor->nb[i]; } @@ -985,11 +1051,12 @@ bool rpc_server::set_tensor(const std::vector & input) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size); @@ -1000,7 +1067,9 @@ bool rpc_server::set_tensor(const std::vector & input) { const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { - GGML_ABORT("[%s] tensor->data out of bounds\n", __func__); + GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, in_tensor->data, offset, size, p0, p1); + return false; } } @@ -1016,7 +1085,6 @@ bool rpc_server::set_tensor(const std::vector & input) { printf("[%s] saved to '%s'\n", __func__, cache_file.c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); - ggml_free(ctx); return true; } @@ -1039,18 +1107,10 @@ bool rpc_server::get_cached_file(uint64_t hash, std::vector & data) { return true; } -bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set_tensor_hash_rsp & response) +bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response) { - // serialization format: | rpc_tensor | offset (8 bytes) | hash (8 bytes) | - if (input.size() != sizeof(rpc_tensor) + 16) { - return false; - } - const rpc_tensor * in_tensor = (const rpc_tensor *)input.data(); - uint64_t offset; - memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset)); - const uint64_t * hash = (const uint64_t *)(input.data() + sizeof(rpc_tensor) + sizeof(offset)); std::vector cached_file; - if (!get_cached_file(*hash, cached_file)) { + if (!get_cached_file(request.hash, cached_file)) { response.result = 0; return true; } @@ -1060,27 +1120,32 @@ bool rpc_server::set_tensor_hash(const std::vector & input, rpc_msg_set /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); - ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } - GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size, *hash); + GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n", + __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash); // sanitize tensor->data { const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer); const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer); - if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) { - GGML_ABORT("[%s] tensor->data out of bounds\n", __func__); + if (request.tensor.data + request.offset < p0 + || request.tensor.data + request.offset >= p1 + || size > (p1 - request.tensor.data - request.offset)) { + GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, request.tensor.data, request.offset, size, request.hash, p0, p1); + return false; } } - ggml_backend_tensor_set(tensor, cached_file.data(), offset, size); + ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size); response.result = 1; - ggml_free(ctx); return true; } @@ -1090,11 +1155,12 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n"); - ggml_free(ctx); return false; } @@ -1110,11 +1176,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { // This pointer can either be passed around client/server, or probably better stored server-side and kept track of. // Currently unimplemented. GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n"); - ggml_free(ctx); return false; } - ggml_free(ctx); return true; } @@ -1124,11 +1188,12 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); if (tensor == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); - ggml_free(ctx); return false; } GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size); @@ -1141,13 +1206,14 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< if (request.tensor.data + request.offset < p0 || request.tensor.data + request.offset >= p1 || request.size > (p1 - request.tensor.data - request.offset)) { - GGML_ABORT("[%s] tensor->data out of bounds\n", __func__); + GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n", + __func__, request.tensor.data, request.offset, request.size, p0, p1); + return false; } } response.resize(request.size, 0); ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size); - ggml_free(ctx); return true; } @@ -1157,12 +1223,14 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); + ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); if (src == nullptr || dst == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); - ggml_free(ctx); return false; } @@ -1180,7 +1248,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co dst_data + src_size, dst_base, dst_base + dst_buf_sz); - ggml_free(ctx); return false; } @@ -1188,7 +1255,6 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co __func__, (void*) src->buffer, (void*) dst->buffer); response.result = ggml_backend_buffer_copy_tensor(src, dst); - ggml_free(ctx); return true; } @@ -1196,22 +1262,50 @@ ggml_tensor * rpc_server::create_node(uint64_t id, struct ggml_context * ctx, const std::unordered_map & tensor_ptrs, std::unordered_map & tensor_map) { - if (id == 0) { - return nullptr; - } if (tensor_map.find(id) != tensor_map.end()) { return tensor_map[id]; } - const rpc_tensor * tensor = tensor_ptrs.at(id); + // Safely find the tensor pointer + auto it_ptr = tensor_ptrs.find(id); + if (it_ptr == tensor_ptrs.end()) { + return nullptr; + } + const rpc_tensor * tensor = it_ptr->second; + struct ggml_tensor * result = deserialize_tensor(ctx, tensor); if (result == nullptr) { return nullptr; } tensor_map[id] = result; for (int i = 0; i < GGML_MAX_SRC; i++) { - result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + // Check if the source ID is 0 before calling create_node recursively + if (tensor->src[i] == 0) { + result->src[i] = nullptr; + } else { + result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + // If the recursive call failed for a non-zero ID, propagate the error + if (result->src[i] == nullptr) { + GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n", + __func__, i, tensor->src[i], id); + // Must return nullptr to signal failure up the call stack + return nullptr; + } + } + } + + // Handle view_src similarly + if (tensor->view_src == 0) { + result->view_src = nullptr; + } else { + result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); + // If the recursive call failed for a non-zero ID, propagate the error + if (result->view_src == nullptr) { + GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n", + __func__, tensor->view_src, id); + // Must return nullptr to signal failure up the call stack + return nullptr; + } } - result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); result->view_offs = tensor->view_offs; return result; } @@ -1237,12 +1331,15 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); + struct ggml_init_params params = { /*.mem_size =*/ buf_size, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - struct ggml_context * ctx = ggml_init(params); + ggml_context_ptr ctx_ptr { ggml_init(params) }; + GGML_ASSERT(ctx_ptr != nullptr); + ggml_context * ctx = ctx_ptr.get(); struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); graph->n_nodes = n_nodes; std::unordered_map tensor_ptrs; @@ -1254,10 +1351,17 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph int64_t id; memcpy(&id, &nodes[i], sizeof(id)); graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map); + + // Check if create_node failed for a *non-zero* ID. + // If id was 0, create_node returning nullptr is expected. + // If id was non-zero and create_node returned nullptr, it indicates a deserialization error. + if (graph->nodes[i] == nullptr && id != 0) { + GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id); + return false; + } } ggml_status status = ggml_backend_graph_compute(backend, graph); response.result = status; - ggml_free(ctx); return true; } @@ -1270,8 +1374,24 @@ rpc_server::~rpc_server() { static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, sockfd_t sockfd, size_t free_mem, size_t total_mem) { rpc_server server(backend, cache_dir); + uint8_t cmd; + if (!recv_data(sockfd, &cmd, 1)) { + return; + } + // the first command sent by the client must be HELLO + if (cmd != RPC_CMD_HELLO) { + fprintf(stderr, "Expected HELLO command, update client\n"); + return; + } + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_hello_rsp response; + server.hello(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } while (true) { - uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { break; } @@ -1281,6 +1401,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } switch (cmd) { + case RPC_CMD_HELLO: { + // HELLO command is handled above + return; + } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; if (!recv_msg(sockfd, &request, sizeof(request))) { @@ -1299,7 +1423,9 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, return; } rpc_msg_get_alloc_size_rsp response; - server.get_alloc_size(request, response); + if (!server.get_alloc_size(request, response)) { + return; + } if (!send_msg(sockfd, &response, sizeof(response))) { return; } @@ -1375,18 +1501,15 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, if (!server.set_tensor(input)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { - return; - } break; } case RPC_CMD_SET_TENSOR_HASH: { - std::vector input; - if (!recv_msg(sockfd, input)) { + rpc_msg_set_tensor_hash_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; - if (!server.set_tensor_hash(input, response)) { + if (!server.set_tensor_hash(request, response)) { return; } if (!send_msg(sockfd, &response, sizeof(response))) { @@ -1472,6 +1595,14 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, const char * cache_dir, size_t free_mem, size_t total_mem) { + printf("Starting RPC server v%d.%d.%d\n", + RPC_PROTO_MAJOR_VERSION, + RPC_PROTO_MINOR_VERSION, + RPC_PROTO_PATCH_VERSION); + printf(" endpoint : %s\n", endpoint); + printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a"); + printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024)); + std::string host; int port; if (!parse_endpoint(endpoint, host, port)) { @@ -1631,6 +1762,9 @@ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const ch if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) { return (void *)ggml_backend_rpc_add_device; } + if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) { + return (void *)ggml_backend_rpc_start_server; + } return NULL; GGML_UNUSED(reg); diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 6699b70ba..efd78b912 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -13,7 +13,7 @@ elseif(SUPPORTS_SYCL) If you expected the oneAPI Release compiler, please install oneAPI & source it, like: source /opt/intel/oneapi/setvars.sh") else() - message(FATAL_ERROR, "C++ compiler lacks SYCL support.") + message(FATAL_ERROR "C++ compiler lacks SYCL support.") endif() message(STATUS "SYCL found") #todo: AOT @@ -49,35 +49,38 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") # Link against oneDNN -find_package(DNNL) set(GGML_SYCL_DNNL 0) -if(DNNL_FOUND) - if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR) - # Assuming oneDNN packaged with oneapi release is used which - # supports only intel target - set(DNNL_GPU_VENDOR "INTEL") - if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") - message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") +if(GGML_SYCL_DNN) + find_package(DNNL) + if(DNNL_FOUND) + if (NOT DEFINED DNNL_GPU_VENDOR) + # default to intel target + set(DNNL_GPU_VENDOR "INTEL") + if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL") + message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target") + endif() endif() - endif() - # Verify oneDNN was compiled for the same target as llama - if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}") - target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) - set(GGML_SYCL_DNNL 1) - get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS) - foreach(CONFIG ${CONFIGS}) - get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG}) - message(STATUS "Found oneDNN: ${DNNL_LIB}") - endforeach() + # Verify oneDNN was compiled for the same target as llama + if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}") + target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl) + set(GGML_SYCL_DNNL 1) + get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS) + foreach(CONFIG ${CONFIGS}) + get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG}) + message(STATUS "Found oneDNN: ${DNNL_LIB}") + endforeach() + else() + message(WARNING + "oneDNN must be compiled for the same target as llama.cpp. + llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}. + Disabling oneDNN support.") + endif() else() - message(WARNING - "oneDNN must be compiled for the same target as llama.cpp. - llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}. - Disabling oneDNN support.") + message(STATUS "oneDNN not found, disabling oneDNN support") endif() else() - message(STATUS "oneDNN not found, disabling oneDNN support") + message(STATUS "oneDNN support disabled by the user") endif() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL}) @@ -108,6 +111,9 @@ endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically # See https://github.com/uxlfoundation/oneMath/issues/654 + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(SYCL_COMPILER ON) + endif() find_package(MKL REQUIRED) target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) @@ -136,7 +142,7 @@ else() FetchContent_Declare( ONEMATH GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git - GIT_TAG c255b1b4c41e2ee3059455c1f96a965d6a62568a + GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142 ) FetchContent_MakeAvailable(ONEMATH) # Create alias to match with find_package targets name @@ -164,7 +170,7 @@ else() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) elseif (GGML_SYCL_TARGET STREQUAL "AMD") if (NOT GGML_SYCL_DEVICE_ARCH) - message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") + message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") endif() target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 73d807cab..f78a36ddf 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -13,23 +13,25 @@ #ifndef GGML_SYCL_BACKEND_HPP #define GGML_SYCL_BACKEND_HPP -#include "concat.hpp" +#include "binbcast.hpp" #include "common.hpp" +#include "concat.hpp" #include "conv.hpp" #include "convert.hpp" +#include "cpy.hpp" #include "dequantize.hpp" #include "dmmv.hpp" +#include "element_wise.hpp" +#include "gla.hpp" +#include "im2col.hpp" #include "mmq.hpp" #include "mmvq.hpp" -#include "rope.hpp" #include "norm.hpp" +#include "outprod.hpp" +#include "quants.hpp" +#include "rope.hpp" #include "softmax.hpp" #include "tsembd.hpp" -#include "im2col.hpp" #include "wkv.hpp" -#include "outprod.hpp" -#include "element_wise.hpp" -#include "cpy.hpp" -#include "gla.hpp" -#endif // GGML_SYCL_BACKEND_HPP +#endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp new file mode 100644 index 000000000..0a3883ae1 --- /dev/null +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -0,0 +1,345 @@ +#include "binbcast.hpp" + +#include +#include +#include + +#include "ggml.h" + +template +static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) / + ne3; + const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + + item_ct1.get_local_id(0)) % + ne3; + + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; + i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s00,*/ int s01, int s02, int s03, + /*int s10,*/ int s11, int s12, int s13, + const sycl::nd_item<3> &item_ct1) { + + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); +} + + +template +struct bin_bcast_sycl { + template + void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00, + const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11, + const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2, + const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, + const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, + const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { + int nr0 = ne10 / ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne[] = {ne0, ne1, ne2, ne3}; + int64_t cne0[] = {ne00, ne01, ne02, ne03}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb[] = {nb0, nb1, nb2, nb3}; + size_t cnb0[] = {nb00, nb01, nb02, nb03}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_UNUSED(s00); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + GGML_ASSERT(s0 == 1); + GGML_ASSERT(s10 == 1); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + sycl::range<3> block_dims(1, 1, 1); + block_dims[2] = std::min(hne0, block_size); + block_dims[1] = std::min( + ne1, block_size / (unsigned int)block_dims[2]); + block_dims[0] = std::min( + std::min( + ne2 * ne3, block_size / (unsigned int)block_dims[2] / + (unsigned int)block_dims[1]), + 64U); + + sycl::range<3> block_nums( + (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], + (ne1 + block_dims[1] - 1) / block_dims[1], + (hne0 + block_dims[2] - 1) / block_dims[2]); + + if (block_nums[0] > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * + sycl::range<3>(1, 1, block_size), + sycl::range<3>(1, 1, block_size)), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast_unravel( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, + s03, s11, s12, s13, item_ct1); + }); + } + } else { + /* + DPCT1049:16: The work-group size passed to the SYCL kernel may + exceed the limit. To get the device limit, query + info::device::max_work_group_size. Adjust the work-group size if + needed. + */ + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, + ne2, ne3, ne10, ne11, ne12, ne13, + s1, s2, s3, s01, s02, s03, s11, s12, s13, + item_ct1); + }); + } + } + } +}; + +template +inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst) { + dpct::queue_ptr main_stream = ctx.stream(); + GGML_TENSOR_BINARY_OP_LOCALS + + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10, + ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, + ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01, + ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, + nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), + main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, + ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, + nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { + op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + +inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + + ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); +} + +inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + ggml_sycl_op_bin_bcast>(ctx, dst, dst->src[0], dst); +} + + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_add(ctx, dst); +} + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_sub(ctx, dst); +} + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_mul(ctx, dst); +} + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_div(ctx, dst); +} + +void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_repeat(ctx, dst); +} + diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp new file mode 100644 index 000000000..9cce0f053 --- /dev/null +++ b/ggml/src/ggml-sycl/binbcast.hpp @@ -0,0 +1,39 @@ +#ifndef GGML_SYCL_BINBCAST_HPP +#define GGML_SYCL_BINBCAST_HPP +#include "common.hpp" + + +static __dpct_inline__ float op_repeat(const float a, const float b) { + return b; + GGML_UNUSED(a); +} + +static __dpct_inline__ float op_add(const float a, const float b) { + return a + b; +} + +static __dpct_inline__ float op_sub(const float a, const float b) { + return a - b; +} + +static __dpct_inline__ float op_mul(const float a, const float b) { + return a * b; +} + +static __dpct_inline__ float op_div(const float a, const float b) { + return a / b; +} + +void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + + +#endif //GGML_SYCL_BINBCAST_HPP + diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 3e1ceeaa4..753b4af14 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -13,8 +13,10 @@ #ifndef GGML_SYCL_COMMON_HPP #define GGML_SYCL_COMMON_HPP +#include #include #include +#include #include "dpct/helper.hpp" #include "ggml-sycl.h" @@ -42,12 +44,22 @@ void ggml_sycl_host_free(void* ptr); extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; +extern int g_ggml_sycl_prioritize_dmmv; -#define GGML_SYCL_DEBUG(...) \ - do { \ - if (g_ggml_sycl_debug) \ - fprintf(stderr, __VA_ARGS__); \ - } while (0) +#if defined(__clang__) && __has_builtin(__builtin_expect) +// Hint the optimizer to pipeline the more likely following instruction in branches +# define LIKELY(expr) __builtin_expect(expr, true) +# define UNLIKELY(expr) __builtin_expect(expr, false) +#else +# define LIKELY(expr) (expr) +# define UNLIKELY(expr) (expr) +#endif + +#define GGML_SYCL_DEBUG(...) \ + do { \ + if (UNLIKELY(g_ggml_sycl_debug)) \ + fprintf(stderr, __VA_ARGS__); \ + } while (0) #define CHECK_TRY_ERROR(expr) \ [&]() { \ @@ -80,10 +92,6 @@ extern int g_ggml_sycl_disable_optimize; // max batch size to use MMQ kernels when tensor cores are available #define MMQ_MAX_BATCH_SIZE 32 -#if defined(_MSC_VER) -#pragma warning(disable : 4244 4267) // possible loss of data -#endif - // dmmv = dequantize_mul_mat_vec #ifndef GGML_SYCL_DMMV_X #define GGML_SYCL_DMMV_X 32 @@ -118,17 +126,12 @@ static void crash() { GGML_ABORT("SYCL error"); } -#define SYCL_CHECK(err) \ - do { \ - auto err_ = (err); \ - if (err_ != 0) \ - ggml_sycl_error( \ - #err, \ - __func__, \ - __FILE__, \ - __LINE__, \ - "Meet error in this line code!"); \ - } while (0) +#define SYCL_CHECK(err) \ + do { \ + auto err_ = (err); \ + if (err_ != 0) \ + ggml_sycl_error(#err, __func__, __FILE__, __LINE__, "Exception caught in this line of code."); \ + } while (0) #if DPCT_COMPAT_RT_VERSION >= 11100 #define GGML_SYCL_ASSUME(x) __builtin_assume(x) @@ -146,8 +149,6 @@ typedef sycl::float2 dfloat2; #define MMVQ_MAX_BATCH_SIZE 8 -static const int8_t kvalues_iq4nl[16]={-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - static int g_all_sycl_device_count = -1; static bool g_ggml_backend_sycl_buffer_type_initialized = false; @@ -313,7 +314,6 @@ struct ggml_backend_sycl_context { int device; std::string name; optimize_feature opt_feature; - bool optimized_graph=false; queue_ptr qptrs[GGML_SYCL_MAX_DEVICES][GGML_SYCL_MAX_STREAMS] = { { nullptr } }; @@ -480,6 +480,19 @@ static __dpct_inline__ float warp_reduce_max(float x, return x; } +/* Helper for Computing the linear offset of a ggml_tensor given +per-dimension sizes, strides, and indices */ +template +__dpct_inline__ size_t calculate_offset(const std::array & strides, const std::array & indices) { + size_t offset = 0; +#pragma unroll + for (int i = 0; i < N; i++) { + auto index_i = indices[i]; + offset += strides[i] * index_i; + } + return offset; +} + // Helper for vec loading aligned data template inline sycl::vec vec_aligned_load(const Tp* aligned_ptr) { @@ -494,286 +507,78 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor acc) { int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size); -template -static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); - const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) / - ne3; - const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) + - item_ct1.get_local_id(0)) % - ne3; - - if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - for (int i0 = i0s; i0 < ne0; - i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); - } -} - -template -static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0, int ne1, int ne2, int ne3, - int ne10, int ne11, int ne12, int ne13, - /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, - const sycl::nd_item<3> &item_ct1) { - - const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); - - const int i3 = i/(ne2*ne1*ne0); - const int i2 = (i/(ne1*ne0)) % ne2; - const int i1 = (i/ne0) % ne1; - const int i0 = i % ne0; - - if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { - return; - } - - const int i11 = i1 % ne11; - const int i12 = i2 % ne12; - const int i13 = i3 % ne13; - - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; - - const src0_t * src0_row = src0 + i_src0; - const src1_t * src1_row = src1 + i_src1; - dst_t * dst_row = dst + i_dst; - - const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); -} - - -template -struct bin_bcast_sycl { - template - void operator()(ggml_backend_sycl_context & ctx, - const struct ggml_tensor *src0, - const struct ggml_tensor *src1, struct ggml_tensor *dst, - const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd, - queue_ptr stream) { - - GGML_TENSOR_BINARY_OP_LOCALS - - int nr0 = ne10/ne0; - int nr1 = ne11/ne1; - int nr2 = ne12/ne2; - int nr3 = ne13/ne3; - - int nr[4] = { nr0, nr1, nr2, nr3 }; - - // collapse dimensions until first broadcast dimension - int64_t cne[] = {ne0, ne1, ne2, ne3}; - int64_t cne0[] = {ne00, ne01, ne02, ne03}; - int64_t cne1[] = {ne10, ne11, ne12, ne13}; - size_t cnb[] = {nb0, nb1, nb2, nb3}; - size_t cnb0[] = {nb00, nb01, nb02, nb03}; - size_t cnb1[] = {nb10, nb11, nb12, nb13}; - auto collapse = [](int64_t cne[]) { - cne[0] *= cne[1]; - cne[1] = cne[2]; - cne[2] = cne[3]; - cne[3] = 1; - }; - - auto collapse_nb = [](size_t cnb[], int64_t cne[]) { - cnb[1] *= cne[1]; - cnb[2] *= cne[2]; - cnb[3] *= cne[3]; - }; - - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - for (int i = 0; i < 4; i++) { - if (nr[i] != 1) { - break; - } - if (i > 0) { - collapse_nb(cnb, cne); - collapse_nb(cnb0, cne0); - collapse_nb(cnb1, cne1); - collapse(cne); - collapse(cne0); - collapse(cne1); - } - } - } - { - int64_t ne0 = cne[0]; - int64_t ne1 = cne[1]; - int64_t ne2 = cne[2]; - int64_t ne3 = cne[3]; - - int64_t ne10 = cne1[0]; - int64_t ne11 = cne1[1]; - int64_t ne12 = cne1[2]; - int64_t ne13 = cne1[3]; - - size_t nb0 = cnb[0]; - size_t nb1 = cnb[1]; - size_t nb2 = cnb[2]; - size_t nb3 = cnb[3]; - - size_t nb00 = cnb0[0]; - size_t nb01 = cnb0[1]; - size_t nb02 = cnb0[2]; - size_t nb03 = cnb0[3]; - - size_t nb10 = cnb1[0]; - size_t nb11 = cnb1[1]; - size_t nb12 = cnb1[2]; - size_t nb13 = cnb1[3]; - - size_t s0 = nb0 / sizeof(dst_t); - size_t s1 = nb1 / sizeof(dst_t); - size_t s2 = nb2 / sizeof(dst_t); - size_t s3 = nb3 / sizeof(dst_t); - - size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - - size_t s00 = nb00 / sizeof(src0_t); - size_t s01 = nb01 / sizeof(src0_t); - size_t s02 = nb02 / sizeof(src0_t); - size_t s03 = nb03 / sizeof(src0_t); - - GGML_UNUSED(s00); - - GGML_ASSERT(nb0 % sizeof(dst_t) == 0); - GGML_ASSERT(nb1 % sizeof(dst_t) == 0); - GGML_ASSERT(nb2 % sizeof(dst_t) == 0); - GGML_ASSERT(nb3 % sizeof(dst_t) == 0); - - GGML_ASSERT(nb00 % sizeof(src0_t) == 0); - GGML_ASSERT(nb01 % sizeof(src0_t) == 0); - GGML_ASSERT(nb02 % sizeof(src0_t) == 0); - GGML_ASSERT(nb03 % sizeof(src0_t) == 0); - - GGML_ASSERT(nb10 % sizeof(src1_t) == 0); - GGML_ASSERT(nb11 % sizeof(src1_t) == 0); - GGML_ASSERT(nb12 % sizeof(src1_t) == 0); - GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - - const int block_size = 128; - - int64_t hne0 = std::max(ne0/2LL, 1LL); - - sycl::range<3> block_dims(1, 1, 1); - block_dims[2] = std::min(hne0, block_size); - block_dims[1] = std::min( - ne1, block_size / (unsigned int)block_dims[2]); - block_dims[0] = std::min( - std::min( - ne2 * ne3, block_size / (unsigned int)block_dims[2] / - (unsigned int)block_dims[1]), - 64U); - - sycl::range<3> block_nums( - (ne2 * ne3 + block_dims[0] - 1) / block_dims[0], - (ne1 + block_dims[1] - 1) / block_dims[1], - (hne0 + block_dims[2] - 1) / block_dims[2]); - - if (block_nums[0] > 65535) { - // this is the maximum number of blocks in z direction, fallback to 1D grid kernel - int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) * - sycl::range<3>(1, 1, block_size), - sycl::range<3>(1, 1, block_size)), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast_unravel( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, - s03, s11, s12, s13, item_ct1); - }); - } - } else { - /* - DPCT1049:16: The work-group size passed to the SYCL kernel may - exceed the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if - needed. - */ - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); - - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, - ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s01, s02, s03, s11, s12, s13, - item_ct1); - }); - } - } - GGML_UNUSED(ctx); - } -}; - -template -inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, - const ggml_tensor *src1, ggml_tensor *dst) { - dpct::queue_ptr main_stream = ctx.stream(); - - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, - (sycl::half *)dst->data, main_stream); - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data, - main_stream); - } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { - op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data, - main_stream); - } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { - op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data, - main_stream); - } else { - fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, - ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); - } +constexpr size_t ceil_div(const size_t m, const size_t n) { + return (m + n - 1) / n; } bool gpu_has_xmx(sycl::device &dev); + +template std::string debug_get_array_str(const std::string & prefix, const T array[N]) { + if (LIKELY(!g_ggml_sycl_debug)) { + return ""; + } + std::stringstream ss; + ss << prefix << "=["; + for (std::size_t i = 0; i < N - 1; ++i) { + ss << array[i] << ", "; + } + if constexpr (N > 0) { + ss << array[N - 1]; + } + ss << "]"; + return ss.str(); +} + +inline std::string debug_get_tensor_str(const std::string &prefix, + const ggml_tensor *tensor, const std::string &suffix = "") { + std::stringstream ss; + if (LIKELY(!g_ggml_sycl_debug)) { return ss.str(); } + ss << prefix.c_str() << "="; + if (tensor) { + ss << "'" << tensor->name << "':type=" << ggml_type_name(tensor->type); + ss << debug_get_array_str(";ne", tensor->ne); + ss << debug_get_array_str(";nb", tensor->nb); + + if (!ggml_is_contiguous(tensor)) { ss << ";strided"; } + if (ggml_is_permuted(tensor)) { ss << ";permuted"; } + } else { + ss << "nullptr"; + } + ss << suffix; + return ss.str(); +} + +// Use scope_op_debug_print to log operations coming from running a model +struct scope_op_debug_print { + // Use string_views to avoid the cost of creating a string and concatenating them + // string_views must be alive for as long as the object is alive + // scope_op_debug_print are used with string literals in practice which are stored in constant space so always accessible + scope_op_debug_print(const std::string_view & func, const std::string_view & func_suffix, const ggml_tensor * dst, + std::size_t num_src, const std::string_view & suffix = "") : + func(func), + func_suffix(func_suffix) { + if (LIKELY(!g_ggml_sycl_debug)) { + return; + } + GGML_SYCL_DEBUG("[SYCL][OP] call %s%s:", func.data(), func_suffix.data()); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" dst", dst).c_str()); + if (dst) { + for (std::size_t i = 0; i < num_src; ++i) { + GGML_SYCL_DEBUG("%s", debug_get_tensor_str("\tsrc" + std::to_string(i), dst->src[i]).c_str()); + } + } + GGML_SYCL_DEBUG("%s\n", suffix.data()); + } + + scope_op_debug_print(const std::string_view & func, const ggml_tensor * dst, std::size_t num_src, + const std::string_view & suffix = "") : + scope_op_debug_print(func, "", dst, num_src, suffix) {} + + ~scope_op_debug_print() { GGML_SYCL_DEBUG("[SYCL][OP] call %s%s done\n", func.data(), func_suffix.data()); } + + private: + std::string_view func; + std::string_view func_suffix; +}; + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp index d41cfd3a6..7aa91c861 100644 --- a/ggml/src/ggml-sycl/concat.cpp +++ b/ggml/src/ggml-sycl/concat.cpp @@ -159,39 +159,37 @@ static void concat_f32_sycl_non_cont( } void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - queue_ptr stream = ctx.stream(); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + queue_ptr stream = ctx.stream(); - const int32_t dim = ((int32_t *)dst->op_params)[0]; + const int32_t dim = ((int32_t *) dst->op_params)[0]; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float *src0_d = (const float *)src0->data; - const float *src1_d = (const float *)src1->data; + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; - float *dst_d = (float *)dst->data; + float * dst_d = (float *) dst->data; - if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_sycl( - src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], - src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); - } + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_sycl(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * (dst->nb[3] / 4), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], + dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait())); + SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait())); + } } else { - const size_t size0 = ggml_nbytes(src0); - const size_t size1 = ggml_nbytes(src1); - - SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(dst_d, src0_d, size0).wait())); - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memcpy(dst_d + size0 / 4, src1_d, size1).wait())); + concat_f32_sycl_non_cont(stream, (const char *) src0->data, (const char *) src1->data, (char *) dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], + src0->nb[2], src0->nb[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], + dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); } - } else - concat_f32_sycl_non_cont( - stream, (const char *)src0->data, (const char *)src1->data, - (char *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], - src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], src1->nb[0], src1->nb[1], - src1->nb[2], src1->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], - dst->ne[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dim); } diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp index ddba601e1..475bd34a2 100644 --- a/ggml/src/ggml-sycl/conv.cpp +++ b/ggml/src/ggml-sycl/conv.cpp @@ -72,6 +72,7 @@ static void conv_transpose_1d_f32_f32_sycl( } void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); const ggml_tensor *src0 = dst->src[0]; const ggml_tensor *src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 76ac6a4dd..96d2583b1 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -183,6 +183,24 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k, } } +template +static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + const size_t local_size = 32; + const size_t global_size = nb * local_size; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(12), cgh); + + cgh.parallel_for(sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)), + [=](sycl::nd_item<1> item_ct1) { + dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); + }); + }); +} + template static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -247,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template +static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); }); +} + template static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -437,41 +466,52 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k } template -static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, - const sycl::nd_item<3> &item_ct1) { +static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03, + const sycl::nd_item<3> & item_ct1) { + const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + + const int64_t i01 = item_ct1.get_group(1); + const int64_t i02 = item_ct1.get_group(0) % ne02; + const int64_t i03 = item_ct1.get_group(0) / ne02; // make each work-item deal with more elements since sycl global range can not exceed max int - const src_t * x = (const src_t *) vx; - for (int64_t i = global_id; i < k; i += work_group_size * item_ct1.get_group_range(2)) { - y[i] = x[i]; + const src_t * x = static_cast(vx); + const int64_t ix = i03 * s03 + i02 * s02 + i01 * s01; + const int64_t iy = ((i03 * ne02 + i02) * ne01 + i01) * ne00; + +#pragma unroll + for (int64_t i00 = global_id; i00 < ne00; i00 += work_group_size * item_ct1.get_group_range(2)) { + y[iy + i00] = static_cast(x[ix + i00]); } } template -static void convert_unary_sycl(const void *__restrict__ vx, - dst_t *__restrict__ y, const int64_t k, - dpct::queue_ptr stream) { - const int64_t num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; +static void convert_unary_nc_sycl(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, dpct::queue_ptr queue) { + dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 }); + + sycl::range<3> global_size(ne02 * ne03, ne01, ceil_div(ne00, SYCL_DEQUANTIZE_BLOCK_SIZE)); // decrease global range when it exceeds the max int - int64_t local_size = downsample_sycl_global_range(num_blocks, SYCL_DEQUANTIZE_BLOCK_SIZE); - sycl::range<3> block_nums(1, 1, num_blocks); - sycl::range<3> local_range(1, 1, local_size); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + // TODO: Downsample logic is separated from the kernel, a rewrite is desirable + int64_t downsized_workgroup = downsample_sycl_global_range(global_size[0], SYCL_DEQUANTIZE_BLOCK_SIZE); + sycl::range<3> workgroup_size(1, 1, downsized_workgroup); - stream->parallel_for( - sycl::nd_range<3>(block_nums * local_range, local_range), - [=](sycl::nd_item<3> item_ct1) { - convert_unary(vx, y, k, item_ct1); - }); - } + queue->parallel_for(sycl::nd_range<3>(global_size * workgroup_size, workgroup_size), [=](sycl::nd_item<3> item_ct1) { + convert_unary_nc(vx, y, ne00, ne01, ne02, s01, s02, s03, item_ct1); + }); } -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) { +template +static void convert_unary_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr queue) { + convert_unary_nc_sycl(vx, y, k, 1, 1, 1, k, k, k, queue); +} + +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { switch (type) { case GGML_TYPE_Q4_0: if (dst->src[0]->extra && @@ -493,11 +533,19 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q3_K: return dequantize_row_q3_K_sycl; case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q6_K_sycl_reorder; + } else { + return dequantize_row_q6_K_sycl; + } case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_sycl; case GGML_TYPE_IQ1_M: @@ -545,11 +593,20 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q3_K: return dequantize_row_q3_K_sycl; case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q4_K_sycl_reorder; + } else { + return dequantize_row_q4_K_sycl; + } case GGML_TYPE_Q5_K: return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q6_K_sycl_reorder; + } else { + return dequantize_row_q6_K_sycl; + } case GGML_TYPE_IQ1_S: return dequantize_row_iq1_s_sycl; case GGML_TYPE_IQ1_M: @@ -574,3 +631,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return nullptr; } } + +to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_nc_sycl; + default: + return nullptr; + } +} diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 355dae22b..f8cb573e3 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: MIT // @@ -16,12 +16,19 @@ #include "common.hpp" template -using to_t_sycl_t = void (*)(const void *__restrict__ x, T *__restrict__ y, - int64_t k, dpct::queue_ptr stream); -typedef to_t_sycl_t to_fp32_sycl_t; +using to_t_sycl_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, dpct::queue_ptr stream); +typedef to_t_sycl_t to_fp32_sycl_t; typedef to_t_sycl_t to_fp16_sycl_t; -to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor *dst); -to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst); +to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); +to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); -#endif // GGML_SYCL_CONVERT_HPP +// Nc = Non-contiguous +template +using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue); + +typedef to_t_nc_sycl_t to_fp16_nc_sycl_t; +to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type); + +#endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp index 5a2314589..bec137140 100644 --- a/ggml/src/ggml-sycl/cpy.cpp +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -1,8 +1,12 @@ #include "cpy.hpp" #include +#include #include "dequantize.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml-sycl/presets.hpp" +#include "ggml.h" static __dpct_inline__ int best_index_int8(int n, const int8_t * val, float x) { if (x <= val[0]) { @@ -116,6 +120,15 @@ static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { } } +/* quantized type same copy */ +template +static void cpy_blck_q_q(const char * cxi, char * cdsti) { + const T * xi = (const T *) cxi; + T * dsti = (T *) cdsti; + *dsti = *xi; +} + + static void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { float * cdstf = (float *) (cdsti); @@ -311,6 +324,34 @@ template static void cpy_blck_q_f32(const } } + +template +static void cpy_q_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, + const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, + const sycl::nd_item<3> & item_ct1) { + const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2)) * qk; + + if (i >= ne) { + return; + } + + const int i03 = i / (ne00 * ne01 * ne02); + const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; + const int i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00; + const int x_offset = (i00 / qk) * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; + + + const int i13 = i / (ne10 * ne11 * ne12); + const int i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11); + const int i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10; + const int i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10; + const int dst_offset = (i10 / qk) * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13; + + cpy_blck_q_q(cx + x_offset, cdst + dst_offset); +} + template static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, @@ -322,6 +363,7 @@ static void cpy_f32_q(const char * cx, char * cdst, const int ne, const int ne00 return; } + const int i03 = i / (ne00 * ne01 * ne02); const int i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); const int i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00; @@ -615,7 +657,73 @@ static void ggml_cpy_i32_i32_sycl(const char * cx, char * cdst, const int ne, co } } +static void ggml_cpy_q8_0_q8_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q5_0_q5_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q5_1_q5_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q4_0_q4_0(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + + +static void ggml_cpy_q4_1_q4_1(const char * cx, char * cdst, const int ne, const int ne00, const int ne01, + const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, queue_ptr stream) { + + const int num_blocks = ceil_div(ne, SYCL_CPY_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { + cpy_q_q(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, item_ct1); + }); +} + void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1) try { + // Unlike other operators ggml_sycl_cpy takes 2 distinct tensors instead of a dst ggml_tensor and rely on its src field + scope_op_debug_print scope_dbg_print(__func__, src1, /*num_src=*/0, debug_get_tensor_str("\tsrc0", src0)); const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -629,10 +737,10 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co char * src0_ddc = (char *) src0->data; char * src1_ddc = (char *) src1->data; - GGML_SYCL_DEBUG("[SYCL] %s: Tensor supplied: %s to %s\n", __func__, ggml_type_name(src0->type), - ggml_type_name(src1->type)); - - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + if ((src0->type == src1->type) && (ggml_is_contiguous(src0) && ggml_is_contiguous(src1))) { + GGML_SYCL_DEBUG("%s: memcpy path\n", __func__); + main_stream->memcpy(src1_ddc, src0_ddc, ggml_nbytes(src0)); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_f32_f32_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { @@ -683,6 +791,16 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { ggml_cpy_f32_iq4_nl_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) { + ggml_cpy_q8_0_q8_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_Q5_0) { + ggml_cpy_q5_0_q5_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_Q5_1) { + ggml_cpy_q5_1_q5_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_q4_0_q4_0(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_q4_1_q4_1(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else { GGML_LOG_ERROR("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); @@ -694,8 +812,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co } void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - // TODO: why do we pass dst as src1 here? - GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_cpy(ctx, dst->src[0], dst); - GGML_SYCL_DEBUG("[SYCL] call %s done\n", __func__); } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 651c2160d..540539bb2 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -357,6 +357,28 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8 } #endif +template +inline void dequantize_q4_K_common(dst_t * __restrict__ y, const uint8_t * __restrict__ qs_ptr, const float dall, + const float dmin, uint8_t * __restrict__ scales_local, int il, int ir) { + const int is = 2 * il; + constexpr int n = 4; + + uint8_t sc, m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; + const float m1 = dmin * m; + + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; + const float m2 = dmin * m; + + sycl::vec q_vec = vec_aligned_load(qs_ptr + 32 * il + n * ir); + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; + y[l + 32] = d2 * (q_vec[l] >> 4) - m2; + } +} + template static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) { @@ -365,36 +387,22 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri const int64_t i = item_ct1.get_group(2); #if QK_K == 256 - // assume 32 threads const int64_t tid = item_ct1.get_local_id(2); - const int64_t il = tid/8; - const int64_t ir = tid%8; - const int64_t is = 2*il; - const int64_t n = 4; + const int64_t il = tid / 8; + const int64_t ir = tid % 8; - dst_t * y = yy + i*QK_K + 64*il + n*ir; + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; const sycl::half2 dm = x[i].dm; const float dall = dm[0]; const float dmin = dm[1]; - if (tid < 12) + if (tid < 12) { scales_local[tid] = x[i].scales[tid]; - item_ct1.barrier(sycl::access::fence_space::local_space); - - uint8_t sc, m; - get_scale_min_k4(is + 0, scales_local, sc, m); - const float d1 = dall * sc; - const float m1 = dmin * m; - get_scale_min_k4(is + 1, scales_local, sc, m); - const float d2 = dall * sc; - const float m2 = dmin * m; - - sycl::vec q_vec = vec_aligned_load(x[i].qs + 32*il + n*ir); - for (int l = 0; l < n; ++l) { - y[l + 0] = d1 * (q_vec[l] & 0xF) - m1; - y[l +32] = d2 * (q_vec[l] >> 4) - m2; } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, x[i].qs, dall, dmin, scales_local, il, ir); #else const int64_t tid = item_ct1.get_local_id(2); const uint8_t * q = x[i].qs; @@ -406,6 +414,36 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri #endif } +template +static void dequantize_block_q4_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, uint8_t * scales_local, + const sycl::nd_item<1> & item_ct1, int64_t nb) { + const int64_t i = item_ct1.get_group(0); // block index + const int64_t tid = item_ct1.get_local_id(0); // thread index within block + const int64_t il = tid / 8; + const int64_t ir = tid % 8; + + dst_t * y = yy + i * QK_K + 64 * il + 4 * ir; + + const uint8_t * base = static_cast(vx); + const size_t qs_offset = i * (QK_K / 2); + const size_t scales_offset = nb * (QK_K / 2) + i * K_SCALE_SIZE; + const size_t dm_offset = nb * (QK_K / 2) + nb * K_SCALE_SIZE + i * sizeof(ggml_half2); + + const uint8_t * qs_ptr = base + qs_offset; + const uint8_t * scales_ptr = base + scales_offset; + ggml_half2 dm_values = *reinterpret_cast(base + dm_offset); + + const float dall = dm_values.x(); + const float dmin = dm_values.y(); + + if (tid < 12) { + scales_local[tid] = scales_ptr[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + dequantize_q4_K_common(y, qs_ptr, dall, dmin, scales_local, il, ir); +} + template static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { @@ -500,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri #endif } +template +static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { + const int64_t ib = item_ct1.get_group(2); + + const int64_t tid = item_ct1.get_local_id(2); + const int64_t ip = tid / 32; // ip is 0 or 1 + const int64_t il = tid - 32 * ip; // 0...32 + const int64_t is = 8 * ip + il / 16; + + const uint8_t * base_ptr = static_cast(vx); + const auto ql_offset = ib * (QK_K / 2); + const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib; + const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib; + const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks; + const uint8_t * ql_ptr = base_ptr + ql_offset; + const uint8_t * qh_ptr = base_ptr + qh_offset; + const uint8_t * scales_ptr = base_ptr + base_scales_offset; + const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib; + + dst_t * y = yy + ib * QK_K + 128 * ip + il; + + const uint8_t * ql = ql_ptr + 64 * ip + il; + const uint8_t qh = *(qh_ptr + 32 * ip + il); + const int8_t * sc = reinterpret_cast(scales_ptr + is); + + y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + template static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1, diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 04a85fa35..4f2760110 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -1092,6 +1092,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec( src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; if (src1_convert_f16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); src1_dfloat = src1_dfloat_a.alloc(ne00); const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); @@ -1129,7 +1131,13 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + // reorder is currently not supported for dmmv + GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + } else { + dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_K: dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index b36aa7a9d..5b7c4f0b4 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -21,6 +21,27 @@ static void acc_f32(const float * x, const float * y, float * dst, const int ne, } } +template +static void sgn(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + dst[i] = x[i] > static_cast(0.f) ? static_cast(1.f) : ((x[i] < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); + } +} + +template +static void abs_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + dst[i] = sycl::fabs(x[i]); + } +} + +template +static void elu_op(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + dst[i] = (x[i] > static_cast(0.f)) ? x[i] : sycl::expm1(x[i]); + } +} + template static void gelu(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { @@ -63,6 +84,15 @@ static void gelu_quick(const T *x, T *dst, int k, dst[i] = x[i] * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i]))); } +template +static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) { + const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); + for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) { + auto x_i = x[i]; + dst[i] = static_cast(0.5f) * x_i * (static_cast(1.0f) + sycl::erf(x_i * SQRT_2_INV)); + } +} + template static void tanh(const T *x, T *dst, int k, const sycl::nd_item<3> &item_ct1) { @@ -335,6 +365,37 @@ static void silu_sycl(const T *x, T *dst, const int k, }); } +template +static void sgn_sycl(const T * x, T * dst, const int k, queue_ptr stream) { + // hard code for now + const int num_blocks = ceil_div(k, 256); + stream->parallel_for( + sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range(1, 1, 256)), sycl::range(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { + sgn(x, dst, k, item_ct1); + }); +} + +template +static void abs_sycl(const T * x, T * dst, const int k, queue_ptr stream) { + // hard code for now + const int num_blocks = ceil_div(k, 256); + stream->parallel_for( + sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { + abs_op(x, dst, k, item_ct1); + }); +} + + +template +static void elu_sycl(const T * x, T * dst, const int k, queue_ptr stream) { + // hard code for now + const int num_blocks = ceil_div(k, 256); + stream->parallel_for( + sycl::nd_range<3>((sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, 256)), sycl::range<3>(1, 1, 256)), [=](sycl::nd_item<3> item_ct1) { + elu_op(x, dst, k, item_ct1); + }); +} + template static void gelu_quick_sycl(const T *x, T *dst, const int k, queue_ptr stream) { @@ -348,6 +409,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k, }); } + +template +static void gelu_erf_sycl(const T *x, T *dst, const int k, + queue_ptr stream) { + const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + gelu_erf(x, dst, k, item_ct1); + }); +} + template static void tanh_sycl(const T *x, T *dst, const int k, queue_ptr stream) { @@ -574,6 +649,103 @@ static void clamp_sycl(const T *x, T *dst, const float min, }); } +inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + sgn_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + +inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + abs_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + + +inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + elu_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); @@ -602,7 +774,6 @@ inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -634,7 +805,6 @@ inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -666,10 +836,41 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } +inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +#if defined (GGML_SYCL_F16) + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); +#else + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); +#endif + GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + switch (dst->type) { +#if defined (GGML_SYCL_F16) + case GGML_TYPE_F16: + { + auto data_pts = cast_data(dst); + gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } +#endif + case GGML_TYPE_F32: + { + auto data_pts = cast_data(dst); + gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream); + break; + } + default: + GGML_ABORT("GGML tensor type not supported!\n"); + } +} + + inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { #if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); @@ -698,7 +899,6 @@ inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -731,7 +931,6 @@ inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -765,7 +964,6 @@ inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tenso } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -797,7 +995,6 @@ inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -829,7 +1026,6 @@ inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -861,7 +1057,6 @@ inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -893,7 +1088,6 @@ inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -926,7 +1120,6 @@ inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -958,7 +1151,6 @@ inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -990,7 +1182,6 @@ inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1022,7 +1213,6 @@ inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1054,7 +1244,6 @@ inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1089,7 +1278,6 @@ inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1121,7 +1309,6 @@ inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1163,7 +1350,6 @@ inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1198,7 +1384,6 @@ inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1236,7 +1421,6 @@ inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * ds } default: GGML_ABORT("GGML tensor type not supported!\n"); - break; } } @@ -1261,177 +1445,127 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) } -inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - -inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - -inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - -inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - - ggml_sycl_op_bin_bcast>(ctx, dst->src[0], dst->src[1], dst); -} - - void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sqrt(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sin(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_cos(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_acc(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_gelu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_silu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_gelu_quick(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); +} + +void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_gelu_erf(ctx, dst); } void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_tanh(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_relu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sigmoid(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_hardsigmoid(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_hardswish(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } - void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_exp(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_log(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_neg(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_step(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_leaky_relu(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_sqr(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_upscale(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_pad(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type)); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_clamp(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } - - - -void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_add(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); +void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_sgn(ctx, dst); } -void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_sub(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); +void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_abs(ctx, dst); } -void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_mul(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - -void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_div(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); +void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_elu(ctx, dst); } diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 76d91ba10..bd40113f0 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -10,27 +10,6 @@ T neg_infinity() { return -std::numeric_limits::infinity(); } -static __dpct_inline__ float op_repeat(const float a, const float b) { - return b; - GGML_UNUSED(a); -} - -static __dpct_inline__ float op_add(const float a, const float b) { - return a + b; -} - -static __dpct_inline__ float op_sub(const float a, const float b) { - return a - b; -} - -static __dpct_inline__ float op_mul(const float a, const float b) { - return a * b; -} - -static __dpct_inline__ float op_div(const float a, const float b) { - return a / b; -} - template struct typed_data { const T * src; @@ -59,6 +38,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); @@ -87,14 +68,10 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -// --------- +void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - -void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - -void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - -void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_ELEMENTWISE_HPP + diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index 4ebbb5b66..5efe03d36 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -32,19 +32,42 @@ public: else static_assert(0); } - static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k, - const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { + // matrix A has m rows, k columns + // matrix B has k rows, n columns + // nra - number of elements to skip when moving into next row in A + // nrb - number of elements to skip when moving into next row in B + // nca - number of elements to skip when moving into next column in A + // ncb - number of elements to skip when moving into next column in B + // stride_a - number of elements to skip when moving to next A matrix + // stride_b - number of elements to skip when moving to next B matrix + // batches_a - number of A matrices + // batches_b - number of B matrices + static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a, + const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b, + void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) { + auto stream = ctx.stream_dnnl(q); auto eng = ctx.engine_dnnl(q); - dnnl::memory::dims a_dims = { m, k }; - dnnl::memory::dims b_dims = { k, n }; - dnnl::memory::dims c_dims = { m, n }; - const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab); - const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab); - const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab); + + // { # strides, # rows, # columns } + dnnl::memory::dims a_dims = { batches_a, m, k }; + dnnl::memory::dims b_dims = { batches_b, k, n }; + dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n }; + + // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column } + dnnl::memory::dims a_strides = { stride_a, nra, nca }; + dnnl::memory::dims b_strides = { stride_b, nrb, ncb }; + + const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides); + const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides); + const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc); dnnl::primitive_attr primitive_attr; primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); +#ifdef GGML_SYCL_F16 + primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16); +#endif auto a_mem = dnnl::memory(a_in_md, eng, const_cast(a)); auto b_mem = dnnl::memory(b_in_md, eng, const_cast(b)); @@ -63,6 +86,15 @@ public: matmul_prim.execute(stream, matmul_args); } + + // matrices A and B are column major, both having k rows + // matrix A has m column, matrix B has n columns + // output: column major matrix C = A transposed * B + static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k, + const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) { + + gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1); + } }; #endif diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index 64665be46..4a7712781 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -257,8 +257,7 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(ctx); } -void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); GGML_ASSERT(dst->type == GGML_TYPE_F32); @@ -308,4 +307,3 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { GGML_ABORT("fatal error"); } } - diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3e48a9244..4b7610362 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -38,6 +38,7 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" +#include "ggml-sycl/element_wise.hpp" #include "ggml-sycl/presets.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/sycl_hw.hpp" @@ -48,6 +49,8 @@ static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; +int g_ggml_sycl_disable_dnn = 0; +int g_ggml_sycl_prioritize_dmmv = 0; static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -194,11 +197,23 @@ static void ggml_check_sycl() try { g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0); g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); + g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); + g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); +#ifdef GGML_SYCL_GRAPH GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); +#endif +#if GGML_SYCL_DNNL + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); +#else + GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); +#endif + GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); GGML_LOG_INFO("Build with Macros:\n"); #if defined(GGML_SYCL_FORCE_MMQ) GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); @@ -331,13 +346,16 @@ static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) { static enum ggml_status ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str()); ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context; if (tensor->view_src != NULL) { assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if (tensor->type == GGML_TYPE_Q4_0) { + if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && + !g_ggml_sycl_disable_optimize) { ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; tensor->extra = extra; ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. @@ -366,20 +384,23 @@ static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { - + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); - SYCL_CHECK( - CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); + SYCL_CHECK(CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw())); +#ifndef _WIN32 // Note: Use host buffer to save the data from mmap(), then copy to device. It's workaround for mmap() issue on PVC GPU. // This function will be called during load model from disk. Use memory buffer replace dynamic won't save more time and brings potential memory leak risk here. - char* host_buf = (char*)malloc(size); + char * host_buf = (char *) malloc(size); memcpy(host_buf, data, size); - SYCL_CHECK( - CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size) - .wait())); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, host_buf, size).wait())); free(host_buf); +#else + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy((char *) tensor->data + offset, data, size).wait())); +#endif } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -391,7 +412,9 @@ static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { - + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; ggml_sycl_set_device(ctx->device); @@ -419,7 +442,12 @@ static bool ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *src, ggml_tensor *dst) try { - if (ggml_backend_buffer_is_sycl(src->buffer)) { + bool is_cpy_supported = ggml_backend_buffer_is_sycl(src->buffer); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str()); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str()); + GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); + if (is_cpy_supported) { ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context; ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context; @@ -476,7 +504,8 @@ ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) try { - ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context; + GGML_SYCL_DEBUG("[SYCL] call %s: size=%zu\n", __func__, buffer->size); + ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; ggml_sycl_set_device(ctx->device); queue_ptr stream = ctx->stream; @@ -495,7 +524,9 @@ catch (sycl::exception const &exc) { static void ggml_backend_sycl_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - GGML_SYCL_DEBUG(" [SYCL] call %s\n", __func__); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu value=%u\n", size, offset, value); ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *) buffer->context; SYCL_CHECK(ggml_sycl_set_device(ctx->device)); auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue()); @@ -773,6 +804,8 @@ static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buff static enum ggml_status ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor, "\n").c_str()); GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; @@ -857,6 +890,9 @@ static void ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); @@ -910,6 +946,9 @@ static void ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); // split tensors must always be set in their entirety at once GGML_ASSERT(offset == 0); GGML_ASSERT(size == ggml_nbytes(tensor)); @@ -1396,6 +1435,59 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, reinterpret_cast(y[ib].ds.y()) = sum; } +template +static __dpct_inline__ void quantize_and_reorder_q8_1(const float * __restrict__ x, void * reordered_q8_tensor, + const int kx, const int kx_padded, const sycl::nd_item<1> & it) { + /* + Quantizes and reorders the resultant q8 tensor in a per row fashion + Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values + */ + + auto subgroup_id = it.get_group(0); + auto wi_id = it.get_local_id(0); + + const int num_blocks_per_row = kx / QK8_1; + auto row = subgroup_id / num_blocks_per_row; + auto col = subgroup_id % num_blocks_per_row; + + auto row_offset = row * (kx_padded / QK8_1) * sizeof(block_q8_1); + auto col_offset = QK8_1 * col + wi_id * ElementsPerWI; + + auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset); + auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof(sycl::half2)); + + sycl::vec wi_f32_vals; + sycl::vec quantized_values; + + auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id; + wi_f32_vals = *reinterpret_cast *>(x + float_ptr_offset); + + float sum = 0.0f; + float amax = 0.0f; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + sum += wi_f32_vals[i]; + amax = sycl::fmax(amax, sycl::fabs(wi_f32_vals[i])); + quantized_values[i] = 0; + } + sum = sycl::reduce_over_group(it.get_group(), sum, sycl::plus()); + amax = sycl::reduce_over_group(it.get_group(), amax, sycl::maximum()); + float d = amax == 0 ? 1 : amax / 127; + +#pragma unroll(ElementsPerWI) + for (int i = 0; i < ElementsPerWI; i++) { + quantized_values[i] = sycl::round(wi_f32_vals[i] / d); + } + + d = amax == 0 ? 0 : d; + + *reinterpret_cast *>(quant_ptr) = quantized_values; + if (wi_id == 0) { + *ds_ptr = sycl::half2(sycl::half(d), sycl::half(sum)); + } +} + static void mul_mat_p021_f16_f32( const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y, @@ -1680,23 +1772,30 @@ static void pool2d_nchw_kernel( o_ptr[cur_oh * ow + cur_ow] = res; } -static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx, - const int ky, const int kx_padded, - queue_ptr stream) { - const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; - const sycl::range<3> num_blocks(1, ky, block_num_x); - int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; - static_assert(QK8_1 % WARP_SIZE == 0); - const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); +static void quantize_row_q8_1_sycl(const float * x, void * vy, const int kx, const int ky, const int kx_padded, + bool reorder_q8_tensor, queue_ptr stream) { + if (reorder_q8_tensor) { + auto local_range = std::size_t(WARP_SIZE); + auto num_quant_blocks = ky * (kx / QK8_1); + auto global_range = num_quant_blocks * local_range; + stream->parallel_for(sycl::nd_range<1>({ global_range }, { local_range }), + [=](sycl::nd_item<1> it) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_and_reorder_q8_1(x, vy, kx, kx_padded, it); + }); + } else { + const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE; + const sycl::range<3> num_blocks(1, ky, block_num_x); + int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE; + static_assert(QK8_1 % WARP_SIZE == 0); + const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE); + { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - stream->parallel_for( - sycl::nd_range<3>(num_blocks * block_size, block_size), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - quantize_q8_1(x, vy, kx, kx_padded, item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(num_blocks * block_size, block_size), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + quantize_q8_1(x, vy, kx, kx_padded, item_ct1); + }); + } } } @@ -1967,11 +2066,6 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - ggml_sycl_op_bin_bcast>(ctx, dst, dst->src[0], dst); -} - - inline void ggml_sycl_op_mul_mat_sycl( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -1986,31 +2080,30 @@ inline void ggml_sycl_op_mul_mat_sycl( const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; - + GGML_ASSERT(ne00 == ne10); const int64_t row_diff = row_high - row_low; int id; SYCL_CHECK( CHECK_TRY_ERROR(id = get_current_device_id())); -#if !GGML_SYCL_DNNL - const int64_t ne0 = dst->ne[0]; + + const int64_t ne0 = dst->ne[0]; // used by MKL only // the main device has a larger memory buffer to hold the results from all GPUs // ldc == nrows of the matrix that cuBLAS writes into - int ldc = id == ctx.device ? ne0 : row_diff; -#endif + int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only #ifdef GGML_SYCL_F16 bool use_fp16 = true; // TODO(Yu) SYCL capability check #else bool use_fp16 = false; #endif - if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && - use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && - dst->op_params[0] == GGML_PREC_DEFAULT) { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n"); + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && + row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); if (src0->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src0 to fp16"); const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = row_diff*ne00; @@ -2023,6 +2116,8 @@ inline void ggml_sycl_op_mul_mat_sycl( ggml_sycl_pool_alloc src1_as_f16(ctx.pool()); if (src1->type != GGML_TYPE_F16) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); GGML_ASSERT(to_fp16_sycl != nullptr); size_t ne = src1_ncols*ne10; @@ -2032,39 +2127,46 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16 ? (const sycl::half *)src1->data + src1_padded_row_size : src1_as_f16.get(); - ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); -#if !GGML_SYCL_DNNL - const sycl::half alpha_f16 = 1.0f; - const sycl::half beta_f16 = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *stream, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, - src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, - dst_f16.get(), dpct::library_data_t::real_half, ldc, - dpct::library_data_t::real_half))); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); - to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); -#else - DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr, - DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), - dst_f16.get(), DnnlGemmWrapper::to_dt(), stream); - const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); - to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr, + DnnlGemmWrapper::to_dt(), src0_ptr, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif - } - else { - // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n"); + { + ggml_sycl_pool_alloc dst_f16(ctx.pool(), row_diff * src1_ncols); + + const sycl::half alpha_f16 = 1.0f; + const sycl::half beta_f16 = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( + *stream, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, + src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, + dst_f16.get(), dpct::library_data_t::real_half, ldc, + dpct::library_data_t::real_half))); + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting dst to fp32"); + const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst); + to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } + } else { ggml_sycl_pool_alloc src0_ddq_as_f32(ctx.pool()); ggml_sycl_pool_alloc src1_ddq_as_f32(ctx.pool()); if (src0->type != GGML_TYPE_F32) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting src0 to fp32"); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src0_ddq_as_f32.alloc(row_diff*ne00); to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); } if (src1->type != GGML_TYPE_F32) { + scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2, + " : converting src1 to fp32"); const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type, dst); GGML_ASSERT(to_fp32_sycl != nullptr); src1_ddq_as_f32.alloc(src1_ncols*ne10); @@ -2073,18 +2175,22 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); -#if !GGML_SYCL_DNNL - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( - get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); -#else - DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, - DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i, + DnnlGemmWrapper::to_dt(), src0_ddf_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( + get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } GGML_UNUSED(dst); GGML_UNUSED(src1_ddq_i); @@ -2096,8 +2202,7 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2149,8 +2254,7 @@ inline void ggml_sycl_op_sum(ggml_backend_sycl_context & ctx, ggml_tensor *dst) sum_rows_f32_sycl(src0_dd, dst_dd, ne, 1, main_stream); } -inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2181,8 +2285,7 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * argsort_f32_i32_sycl(src0_dd, (int *) dst_dd, ncols, nrows, order, main_stream); } -inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); @@ -2197,8 +2300,7 @@ inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor *ds argmax_f32_i32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream); } -inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tensor *dst) { - +inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2215,8 +2317,7 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx,ggml_tens diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); } -inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - +inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); dpct::queue_ptr main_stream = ctx.stream(); @@ -2403,7 +2504,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra)->optimized_feature.reorder; + scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, + /*num_src=*/2, " : converting src1 to Q8_1"); + quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream); /* DPCT1010:90: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You need to @@ -2507,7 +2611,9 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten } if (convert_src1_to_q8_1 && !src1_is_contiguous) { - quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + scope_op_debug_print scope_dbg_print(__func__, "/quantize_row_q8_1_sycl", dst, + /*num_src=*/2, " : converting src1 to Q8_1"); + quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false, stream); /* DPCT1010:92: SYCL uses exceptions to report errors and does not use the error codes. The call was replaced with 0. You @@ -2600,40 +2706,29 @@ catch (sycl::exception const &exc) { } -static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); - ggml_sycl_op_repeat(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); -} - static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_get_rows(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_rms_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_l2_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_l2_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_SYCL_DEBUG("call %s\n", __func__); + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_group_norm(ctx, dst); - GGML_SYCL_DEBUG("call %s done\n", __func__); } static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, @@ -2704,139 +2799,182 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void k_compute_batched_ptrs(const sycl::half *src0_as_f16, - const sycl::half *src1_as_f16, char *dst, - const void **ptrs_src, void **ptrs_dst, - int64_t ne12, int64_t ne13, int64_t ne23, - size_t nb02, size_t nb03, size_t nb12, - size_t nb13, size_t nbd2, size_t nbd3, - int64_t r2, int64_t r3, - const sycl::nd_item<3> &item_ct1) { - int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + - item_ct1.get_local_id(2); - int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); +static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst, + const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23, + size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3, + int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) { + const int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) + item_ct1.get_local_id(2); + const int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); if (i13 >= ne13 || i12 >= ne12) { return; } - int64_t i03 = i13 / r3; - int64_t i02 = i12 / r2; + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; - ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; - ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13; - ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; + const uint8_t * src0_bytes = reinterpret_cast(src0_as_f16); + const uint8_t * src1_bytes = reinterpret_cast(src1_as_f16); + uint8_t * dst_bytes = static_cast(dst); + + ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03; + ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13; + ptrs_dst[0 * ne23 + i12 + i13 * ne12] = dst_bytes + i12 * nbd2 + i13 * nbd3; } -static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, - const ggml_tensor *src0, - const ggml_tensor *src1, - ggml_tensor *dst) try { +static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst) try { GGML_ASSERT(!ggml_is_transposed(src0)); GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer)); GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32); GGML_TENSOR_BINARY_OP_LOCALS + // TODO: see https://github.com/ggml-org/llama.cpp/pull/13155 + // Batched mul_mat requires a rewrite to support both oneDNN and non-contiguous dst + GGML_ASSERT(ggml_is_contiguous(dst)); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - queue_ptr main_stream = ctx.stream();; + queue_ptr queue = ctx.stream(); - void * src0_ddq = src0->data; - sycl::half *src0_as_f16 = (sycl::half *)src0_ddq; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; + dpct::has_capability_or_fail(queue->get_device(), { sycl::aspect::fp16 }); + + const sycl::half * src0_f16 = static_cast(src0->data); + float * dst_ddf = static_cast(dst->data); + + const sycl::half * src1_f16 = static_cast(src1->data); + const size_t type_size_src1 = ggml_type_size(src1->type); + GGML_ASSERT(nb10 == type_size_src1); + + // SRC1 strides + int64_t s11 = nb11 / type_size_src1; + int64_t s12 = nb12 / type_size_src1; + int64_t s13 = nb13 / type_size_src1; + ggml_sycl_pool_alloc src1_f16_alloc(ctx.pool()); // convert src1 to fp16 - ggml_sycl_pool_alloc src1_f16_alloc(ctx.pool()); if (src1->type != GGML_TYPE_F16) { - const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); + scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_nc_sycl", dst, /*num_src=*/2, + " : converting src1 to fp16"); + const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); + GGML_ASSERT(to_fp16_nc_sycl != nullptr); const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); - GGML_ASSERT(to_fp16_sycl != nullptr); - to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream); + to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); + + src1_f16 = src1_f16_alloc.get(); + s11 = ne10; + s12 = ne11 * s11; + s13 = ne12 * s12; } - sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf - : src1_f16_alloc.get(); - char * dst_t; + ggml_sycl_pool_alloc dst_f16(ctx.pool()); - dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float; - dpct::library_data_t cu_data_type = dpct::library_data_t::real_float; + dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float; + dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float; // dst strides size_t nbd2 = dst->nb[2]; size_t nbd3 = dst->nb[3]; const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; + const float beta_f32 = 0.0f; const void * alpha = &alpha_f32; const void * beta = &beta_f32; - dst_t = (char *) dst_ddf; - GGML_ASSERT(ne12 % ne02 == 0); GGML_ASSERT(ne13 % ne03 == 0); + GGML_ASSERT(ne01 == static_cast(nb1/nb0)); + GGML_ASSERT(ne10 == ne00); // broadcast factors - const int64_t r2 = ne12/ne02; - const int64_t r3 = ne13/ne03; + const int64_t r2 = ne12 / ne02; + const int64_t r3 = ne13 / ne03; - if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { - // there is no broadcast and src0, src1 are contiguous across dims 2, 3 - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, - (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t, - cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type))); - } else { - const int ne23 = ne12*ne13; +#if GGML_SYCL_DNNL + if (!g_ggml_sycl_disable_dnn) { + auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12] + (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) { - ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2*ne23); - ggml_sycl_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); - ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); + DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10, + src1, DnnlGemmWrapper::to_dt(), s11, 1, s12, + src0, DnnlGemmWrapper::to_dt(), 1, nb01/nb00, nb02/nb00, + dst, DnnlGemmWrapper::to_dt(), queue, batches_a, batches_b); + }; - sycl::range<3> block_dims(1, ne12, ne13); - /* - DPCT1049:47: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - { - dpct::has_capability_or_fail(main_stream->get_device(), - {sycl::aspect::fp16}); - - main_stream->submit([&](sycl::handler &cgh) { - const void **ptrs_src_get = ptrs_src.get(); - void **ptrs_dst_get = ptrs_dst.get(); - size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2; - size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2; - cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_compute_batched_ptrs( - src0_as_f16, src1_f16, - dst_t, ptrs_src_get, - ptrs_dst_get, ne12, ne13, ne23, - nb02, nb03, nb12_scaled, nb13_scaled, - nbd2, nbd3, r2, r3, item_ct1); - }); - }); + if (r2 == 1 && r3 == 1) { + if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03); + } + else { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes + const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13; + float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02); + } + } + } else { + // iterate over batches from smaller set of matrices (matrix 0) + for (int64_t ie02 = 0; ie02 < ne02; ++ie02) { + for (int64_t ie03 = 0; ie03 < ne03; ++ie03) { + const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half)); + const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3; + float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float)); + dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1); + } + } } - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, - (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, - (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta, - (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get()))); } + else +#endif + { + if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, + oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00, + src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf, + mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); + } else { + const int ne23 = ne12 * ne13; + + ggml_sycl_pool_alloc ptrs_src(ctx.pool(), 2 * ne23); + ggml_sycl_pool_alloc ptrs_dst(ctx.pool(), 1 * ne23); + ggml_sycl_pool_alloc> matrix_info(ctx.host_pool(), 1); + + sycl::range<3> block_dims(1, ne12, ne13); + queue->submit([&](sycl::handler & cgh) { + const void ** ptrs_src_get = ptrs_src.get(); + void ** ptrs_dst_get = ptrs_dst.get(); + size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half); + size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half); + cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02, + nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1); + }); + }); + + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( + *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, + (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, + (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); + } + } +} catch (const sycl::exception & exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl; + std::exit(1); } -catch (sycl::exception const &exc) { - std::cerr << exc.what() << "Exception caught at file:" << __FILE__ - << ", line:" << __LINE__ << std::endl; - std::exit(1); -} + +enum class mul_mat_algo { + DMMV = 0, + MMVQ = 1, + MUL_MAT_SYCL = 2, +}; inline bool ggml_sycl_supports_mmq(enum ggml_type type) { // TODO: accuracy issues in MMQ @@ -2844,6 +2982,38 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) { return false; } +inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + return !g_ggml_sycl_prioritize_dmmv; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return true; + default: + return false; + } +} + +inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + return true; + default: + return false; + } +} + static bool ggml_sycl_supports_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: @@ -2863,13 +3033,190 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { } } -static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK( + CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) + .wait())); + GGML_ASSERT((size % sizeof(block_q4_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); + int offset_blks = offset / sizeof(block_q4_0); + auto qs_ptr = data_device + offset_blks * QK4_0 / 2; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; + stream->parallel_for( + size / sizeof(block_q4_0), + [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const block_q4_0* x = (const block_q4_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK4_0/2; j ++) + { + *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }).wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q4_K) == 0); + GGML_ASSERT(offset % sizeof(block_q4_K) == 0); + + const int nblocks = size / sizeof(block_q4_K); + + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + + auto * qs_ptr = data_device; + auto * scales_ptr = qs_ptr + QK_K / 2 * nblocks; + auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); + + stream->parallel_for(nblocks, [=](auto i) { + const block_q4_K * x = (const block_q4_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 2; ++j) { + qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < K_SCALE_SIZE; ++j) { + scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].dm; + }).wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q6_K) == 0); + GGML_ASSERT(offset % sizeof(block_q6_K) == 0); + + const int nblocks = size / sizeof(block_q6_K); + + auto * tmp_buf = sycl::malloc_shared(size, *stream); + SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait())); + + auto * ql_ptr = data_device; + auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks; + auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks; + sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks); + + stream + ->parallel_for(nblocks, + [=](auto i) { + const block_q6_K * x = (const block_q6_K *) tmp_buf; + const int ib = i; + + const uint8_t * ql = x[ib].ql; + const uint8_t * qh = x[ib].qh; + uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib; + uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib; + uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib; + + for (int j = 0; j < QK_K / 2; ++j) { + base_ql_ptr[j] = ql[j]; + } + for (int j = 0; j < QK_K / 4; ++j) { + base_qh_ptr[j] = qh[j]; + } + + for (int j = 0; j < QK_K / 16; ++j) { + base_scales_ptr[j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].d; + }) + .wait_and_throw(); + + sycl::free(tmp_buf, *stream); +} + +static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { + uint8_t * data_device = (uint8_t *) src0->data; + size_t ncols = src0->ne[0]; + size_t nrows = src0->ne[1]; + size_t size = ggml_nbytes(src0); + + switch (src0->type) { + case GGML_TYPE_Q4_0: + reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); + break; + case GGML_TYPE_Q4_K: + reorder_qw_q4_k(data_device, size, 0, stream); + break; + case GGML_TYPE_Q6_K: + reorder_qw_q6_k(data_device, size, 0, stream); + break; + default: + GGML_ABORT("reorder_qw() called with unsupported type"); + break; + } +} + +static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_tensor * dst) { + return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT + ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. + dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. + dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; +} + +static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, + ggml_tensor * dst, mul_mat_algo mm_algorithm) { + if (!should_reorder_tensor(*ctx, dst)) { + return; + } + + ggml_tensor_extra_gpu * extra = static_cast(src0->extra); + if (!extra || extra->optimized_feature.reorder) { + return; // Skip permutations and already reordered tensors + } + + switch (mm_algorithm) { + case mul_mat_algo::DMMV: + if (!ggml_sycl_supports_reorder_dmmv(src0->type)) { + return; + } + break; + case mul_mat_algo::MMVQ: + if (!ggml_sycl_supports_reorder_mmvq(src0->type)) { + return; + } + break; + case mul_mat_algo::MUL_MAT_SYCL: + if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) { + return; + } + break; + } + + reorder_qw(src0, ctx->stream()); + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering +} + + +static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; +} + +static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + return ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; +} + +static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer); int64_t min_compute_capability = INT_MAX; if (split) { - ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; + ggml_backend_sycl_split_buffer_type_context * buft_ctx = + (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context; auto & tensor_split = buft_ctx->tensor_split; for (int id = 0; id < ggml_sycl_info().device_count; ++id) { // skip devices that are not going to do any work: @@ -2882,17 +3229,13 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor } } } else { - min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; + min_compute_capability = ggml_sycl_info().devices[ctx.device].cc; } // check data types and tensor shapes for custom matrix multiplication kernels: - bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + bool use_dequantize_mul_mat_vec = can_use_dequantize_mul_mat_vec(src0, src1, dst); - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) - && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_vec_q = can_use_mul_mat_vec_q(src0, src1, dst); bool use_mul_mat_q = ggml_sycl_supports_mmq(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; @@ -2904,9 +3247,15 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // SYCL_USE_XMX + // mmvq path is faster in the CUDA backend. - if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda) + if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { // TODO: Refactor and cleanup of mul mat dispatching. @@ -2918,21 +3267,26 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // The kernel from the if path is faster for that specific case, but does not support all mul mats. ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } - } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // KQV single-batch ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false); - // save_tensor_txt("1/dst_1.txt", (float*) dst->data, src0->ne[1], sizeof(float), ctx.stream()); + constexpr bool convert_src1_to_q8_1 = false; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::DMMV); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1); } else if (use_mul_mat_vec_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true); + constexpr bool convert_src1_to_q8_1 = true; + opt_for_reorder(&ctx, src0, src1, dst, mul_mat_algo::MMVQ); + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1); } else if (use_mul_mat_q) { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true); + constexpr bool convert_src1_to_q8_1 = true; + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1); } else { - ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false); + constexpr bool convert_src1_to_q8_1 = false; + ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1); } } @@ -3004,6 +3358,7 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); const ggml_tensor *src0 = dst->src[0]; const ggml_tensor *src1 = dst->src[1]; GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers"); @@ -3172,42 +3527,45 @@ catch (sycl::exception const &exc) { } static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_scale(ctx, dst); } static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_diag_mask_inf(ctx, dst); } -static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented - ggml_sycl_op_rope(ctx, dst); -} - static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_pool2d(ctx, dst); } static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); ggml_sycl_op_im2col(ctx, dst); } static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_sum(ctx, dst); } static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_sum_rows(ctx, dst); } static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_argsort(ctx, dst); } static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); ggml_sycl_op_argmax(ctx, dst); } @@ -3293,6 +3651,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_GELU_QUICK: ggml_sycl_gelu_quick(ctx, dst); break; + case GGML_UNARY_OP_GELU_ERF: + ggml_sycl_gelu_erf(ctx, dst); + break; case GGML_UNARY_OP_TANH: ggml_sycl_tanh(ctx, dst); break; @@ -3311,6 +3672,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_sycl_exp(ctx, dst); break; + case GGML_UNARY_OP_SGN: + ggml_sycl_sgn(ctx, dst); + break; + case GGML_UNARY_OP_ABS: + ggml_sycl_abs(ctx, dst); + break; + case GGML_UNARY_OP_ELU: + ggml_sycl_elu(ctx, dst); + break; default: return false; } @@ -3492,6 +3862,9 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, ggml_tensor *tensor, const void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -3510,13 +3883,16 @@ static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, const ggml_tensor *tensor, void *data, size_t offset, size_t size) try { + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": tensor", tensor).c_str()); + GGML_SYCL_DEBUG(" size=%zu offset=%zu\n", size, offset); ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( - data, (const char *)tensor->data + offset, size).wait())); + data, (const char *)tensor->data + offset, size))); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -3528,7 +3904,13 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor *src, ggml_tensor *dst) try { ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; - if (dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && ggml_backend_buffer_is_sycl(src->buffer)) { + bool is_cpy_supported = dst->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && + ggml_backend_buffer_is_sycl(src->buffer); + GGML_SYCL_DEBUG("[SYCL] call %s", __func__); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(": dst", dst).c_str()); + GGML_SYCL_DEBUG("%s", debug_get_tensor_str(" src", src).c_str()); + GGML_SYCL_DEBUG(" is_cpy_supported=%d\n", is_cpy_supported); + if (is_cpy_supported) { /* DPCT1009:215: SYCL uses exceptions to report errors and does not use the error codes. The original code was commented out and a warning string @@ -3536,7 +3918,7 @@ static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, */ const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy( - dst->data, src->data, ggml_nbytes(dst)).wait())); + dst->data, src->data, ggml_nbytes(dst)))); return true; } @@ -3549,6 +3931,7 @@ catch (sycl::exception const &exc) { } static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try { + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context; const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0); SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait())); @@ -3561,71 +3944,8 @@ catch (sycl::exception const &exc) { std::exit(1); } -static void reorder_qw(char *data_device, const int ncols, const int nrows, - size_t size, size_t offset, dpct::queue_ptr stream) { - auto tmp_buf = sycl::malloc_shared(size, *stream); - SYCL_CHECK( - CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size) - .wait())); - GGML_ASSERT((size % sizeof(block_q4_0) == 0)); - GGML_ASSERT((offset % sizeof(block_q4_0) == 0)); - int offset_blks = offset / sizeof(block_q4_0); - auto qs_ptr = (uint8_t*)data_device + offset_blks * QK4_0 / 2;; - auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows / 2) + offset_blks; - - stream->parallel_for( - size / sizeof(block_q4_0), - [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - const block_q4_0* x = (const block_q4_0*)tmp_buf; - const int ib = i; - - for (int j = 0; j < QK4_0/2; j ++) - { - *(qs_ptr + ib * QK4_0 / 2 + j) = x[ib].qs[j]; - } - *(d_ptr + ib) = x[ib].d; - }); - - sycl::free(tmp_buf, *stream); -} - -static void reorder_qw(ggml_tensor * src0, dpct::queue_ptr stream) { - char*data_device = (char*)src0->data; - size_t ncols = src0->ne[0]; - size_t nrows = src0->ne[1]; - size_t size = ggml_nbytes(src0); - - reorder_qw(data_device, ncols, nrows, size, 0, stream); -} - -static void opt_for_reorder(ggml_tensor * dst, dpct::queue_ptr stream) { - ggml_tensor *src0 = dst->src[0]; - ggml_tensor *src1 = dst->src[1]; - - if (dst->op == GGML_OP_MUL_MAT && src0->type == GGML_TYPE_Q4_0 && - src1->ne[2]==1 && src1->ne[3]==1) { - reorder_qw(src0, stream); - ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra; - GGML_ASSERT(extra); - extra->optimized_feature.reorder = true; //used to decode/dequan in next steps. - } -} - -static void optimize_graph_once(ggml_cgraph * cgraph, ggml_backend_sycl_context * ctx) { - dpct::queue_ptr stream = ctx->stream(); - if (ctx->optimized_graph) { - return; - } - ctx->optimized_graph = true; - - for (int i = 0; i < cgraph->n_nodes; i++) { - if (ctx->opt_feature.reorder) opt_for_reorder(cgraph->nodes[i], stream); - } -} - static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * sycl_ctx, ggml_cgraph * cgraph) { ggml_sycl_set_main_device(sycl_ctx->device); - if (!g_ggml_sycl_disable_optimize) optimize_graph_once(cgraph, sycl_ctx); for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -3648,11 +3968,43 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc } } +#ifdef GGML_SYCL_GRAPH +static bool check_graph_compatibility(ggml_cgraph * cgraph) { + if (ggml_sycl_info().device_count > 1) { + // A sycl_ex::command_graph object can only be created for a single device + GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__); + return false; + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + const ggml_op node_op = cgraph->nodes[i]->op; + switch (node_op) { + default: + break; + case GGML_OP_CONCAT: + // ggml_sycl_op_concat() does a blocking host wait after memcpy operations, + // but wait() can't be called on the events returned by a queue recording + // to a graph. + [[fallthrough]]; + case GGML_OP_MUL_MAT_ID: + // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after + // submitting a memcpy operation, but wait() can't be called on a queue that + // is recording to a graph. + GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__, + ggml_op_name(node_op)); + return false; + } + } + return true; +} +#endif + static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { auto * sycl_ctx = static_cast(backend->context); #ifdef GGML_SYCL_GRAPH - if (!g_ggml_sycl_disable_graph) { + bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility(cgraph); + if (use_sycl_graph) { const bool graph_support = dpct::get_device(sycl_ctx->device).has(sycl::aspect::ext_oneapi_limited_graph); if (!graph_support) { GGML_SYCL_DEBUG("[SYCL-GRAPH] can not use graphs on device:%d\n", sycl_ctx->device); @@ -3660,7 +4012,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_ return GGML_STATUS_SUCCESS; } - sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream())); + sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}}); + model_sycl_graph.begin_recording(*(sycl_ctx->stream())); ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph); model_sycl_graph.end_recording(); @@ -3712,7 +4065,7 @@ catch (sycl::exception const &exc) } static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try { - + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); sycl::event* sycl_event = static_cast(event->context); if (ggml_backend_is_sycl(backend)) { @@ -3854,8 +4207,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_ELU: #if defined (GGML_SYCL_F16) return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type); #else @@ -3916,6 +4273,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g { ggml_type src0_type = op->src[0]->type; ggml_type src1_type = op->src[1]->type; + if (src0_type == src1_type && (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) && src0_type != GGML_TYPE_BF16) { + return true; + } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return true; } @@ -3961,6 +4321,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } + if(src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_Q8_0) { + return true; + } + if(src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_Q5_0) { + return true; + } + if(src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_Q5_1) { + return true; + } + if(src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_Q4_0) { + return true; + } + if(src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_Q4_1) { + return true; + } return false; } case GGML_OP_CONCAT: @@ -3972,7 +4347,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ARGMAX: case GGML_OP_NONE: case GGML_OP_RESHAPE: - case GGML_OP_REPEAT: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: @@ -3982,7 +4356,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return (op->src[0]->type == GGML_TYPE_F32); + case GGML_OP_REPEAT: + return true; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: @@ -3996,6 +4371,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g #endif case GGML_OP_NORM: case GGML_OP_RMS_NORM: + return true; case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); @@ -4007,19 +4383,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SOFT_MAX: return true; case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return ggml_is_contiguous(op->src[0]); - } case GGML_OP_IM2COL: - // TODO: add support for the new F32 operations - return op->src[0]->type == GGML_TYPE_F16; + return true; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; case GGML_OP_POOL_2D: @@ -4107,6 +4472,7 @@ static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_bac static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try { GGML_UNUSED(dev); + GGML_SYCL_DEBUG("[SYCL] call %s\n", __func__); sycl::event *sycl_event = static_cast(event->context); SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait())); diff --git a/ggml/src/ggml-sycl/gla.cpp b/ggml/src/ggml-sycl/gla.cpp index eedb47486..879184fdd 100644 --- a/ggml/src/ggml-sycl/gla.cpp +++ b/ggml/src/ggml-sycl/gla.cpp @@ -76,6 +76,7 @@ static void gated_linear_attn_f32_kernel(const dpct::queue_ptr stream, u_int B, } void ggml_sycl_op_gated_linear_attn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/5); const float * k_d = static_cast(dst->src[0]->data); const float * v_d = static_cast(dst->src[1]->data); const float * r_d = static_cast(dst->src[2]->data); diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 009b42035..aa19c2527 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -12,110 +12,125 @@ #include "im2col.hpp" +#include +#include // For std::is_same_v + +#include "ggml.h" + template -static void im2col_kernel( - const float *x, T *dst, int64_t batch_offset, int64_t offset_delta, - int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, - int64_t pelements, int64_t CHW, int s0, int s1, int p0, int p1, int d0, int d1, - const sycl::nd_item<3> &item_ct1) { +static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, + int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, + int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) { const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + work_group_size * item_ct1.get_group(2); + const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); // make each work-item deal with more elements since sycl global range can not exceed max int - for (int64_t i = global_id; i < pelements; i += work_group_size * item_ct1.get_group_range(2)) { - + for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { const int64_t ksize = OW * (KH > 1 ? KW : 1); - const int64_t kx = i / ksize; - const int64_t kd = kx * ksize; - const int64_t ky = (i - kd) / OW; - const int64_t ix = i % OW; + const int64_t kx = i / ksize; + const int64_t kd = kx * ksize; + const int64_t ky = (i - kd) / OW; + const int64_t ix = i % OW; - const int64_t oh = item_ct1.get_group(1); - const int64_t batch = item_ct1.get_group(0) / IC; - const int64_t ic = item_ct1.get_group(0) % IC; + const int64_t oh = item_ct1.get_group(1); + const int64_t batch = item_ct1.get_group(0) / IC; + const int64_t ic = item_ct1.get_group(0) % IC; - const int64_t iiw = ix * s0 + kx * d0 - p0; - const int64_t iih = oh * s1 + ky * d1 - p1; + const int64_t iiw = (ix * s0) + (kx * d0) - p0; + const int64_t iih = (oh * s1) + (ky * d1) - p1; - const int64_t offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); + const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx); - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = - sycl::vec(0.0f) - .convert()[0]; - } else { - const int64_t offset_src = ic * offset_delta + batch * batch_offset; - dst[offset_dst] = - sycl::vec(x[offset_src + iih * IW + iiw]) - .convert()[0]; + const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset); + const int64_t offset_src = offset_src_base + (iih * IW) + iiw; + + const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW); + const float src_val = out_of_bounds ? 0.0f : x[offset_src]; + + if constexpr (std::is_same_v) { + dst[offset_dst] = sycl::half(src_val); + } else if constexpr (std::is_same_v) { + dst[offset_dst] = src_val; } } } template -static void im2col_sycl( - const float *x, T *dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, - int s0, int s1, int p0, int p1, int d0, int d1, - queue_ptr stream) { +static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, + int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { const int64_t parallel_elements = OW * KW * KH; - const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; // decrease global range when it exceeds the max int int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + sycl::range<3> block_nums(batch * IC, OH, num_blocks); sycl::range<3> local_range(1, 1, local_size); - { - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + const int64_t CHW = IC * KH * KW; - stream->parallel_for( - sycl::nd_range<3>(block_nums * local_range, local_range), - [=](sycl::nd_item<3> item_ct1) { - im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, - parallel_elements, (IC * KH * KW), s0, s1, p0, - p1, d0, d1, item_ct1); - }); - } + stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, + p0, p1, d0, d1, item_ct1); + }); } -void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, + int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, + int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + if (!stream->get_device().has(sycl::aspect::fp16)) { + throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported), + "Device does not support half precision (fp16) operations!"); + } + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, + p1, d0, d1, stream); +} + +static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, + int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0, + int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { + im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, + d0, d1, stream); +} + +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; - const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; - const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; + const int32_t s1 = ((const int32_t *) (dst->op_params))[1]; + const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; + const int32_t p1 = ((const int32_t *) (dst->op_params))[3]; + const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; + const int32_t d1 = ((const int32_t *) (dst->op_params))[5]; - const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; + const int64_t IW = src1->ne[0]; const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; + const int64_t KW = src0->ne[0]; const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; + const int64_t OW = dst->ne[1]; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 - const int64_t batch = src1->ne[3]; - const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32 + const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float); + const int64_t batch = src1->ne[is_2D ? 3 : 2]; + const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float); + + queue_ptr stream = ctx.stream(); if (dst->type == GGML_TYPE_F16) { - im2col_sycl((const float *) src1->data, (sycl::half *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream()); + im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, + batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); } else { - im2col_sycl((const float *) src1->data, (float *)dst->data, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, ctx.stream()); + im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, + batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); } } diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 1b92ba2d6..5b7f06407 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1,6 +1,60 @@ #include "mmvq.hpp" + +#include "ggml.h" +#include "common.hpp" +#include "quants.hpp" #include "vecdotq.hpp" -#include + +template +static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols, const int nrows, const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + float partial_sum = 0.0f; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; // x block index + + const auto bx_offset = block_type::get_block_offset(ibx, nblocks); + const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx); + // Y block index that aligns with ibx + const int iby = i * block_type::block_to_q8_1_ratio(); + const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1; + const sycl::half2* q8_1_ds_ptr = (const sycl::half2*)((const char*)vy + ncols + iby * sizeof(sycl::half2)); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + // x block quant index when casting the quants to int + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + + partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs); + } + } + + auto sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum, std::plus<>()); + + if (sg.leader()) { + dst[row] = sum; + } +} template static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, @@ -480,26 +534,39 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, } } -static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, - float *dst, const int ncols, - const int nrows, +static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { - - stream->submit([&](sycl::handler &cgh) { - - cgh.parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); }); } } @@ -672,6 +739,27 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, + nrows, nd_item); + }); + }); +} + + static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -696,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -916,93 +1022,109 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } -void ggml_sycl_op_mul_mat_vec_q( - ggml_backend_sycl_context & ctx, - const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, - const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, - float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_col_size, - const dpct::queue_ptr &stream) { - +void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, + ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, + const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_col_size, + const dpct::queue_ptr & stream) { const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); - const int64_t ne00 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; int id; - SYCL_CHECK( - CHECK_TRY_ERROR(id = get_current_device_id())); + SYCL_CHECK(CHECK_TRY_ERROR(id = get_current_device_id())); const size_t q8_1_ts = sizeof(block_q8_1); const size_t q8_1_bs = QK8_1; // the main device has a larger memory buffer to hold the results from all GPUs // nrows_dst == nrows of the matrix that the kernel writes into - for (int i = 0; i < src1_ncols; i++) - { + for (int i = 0; i < src1_ncols; i++) { const size_t src1_ddq_i_offset = i * src1_padded_col_size * q8_1_ts / q8_1_bs; - const char* src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; - float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; + const char * src1_ddq_i_bs = src1_ddq_i + src1_ddq_i_offset; + float * dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; switch (src0->type) { - case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - break; - default: - GGML_ABORT("fatal error"); + case GGML_TYPE_Q4_0: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); + mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q4_K: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n"); + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_Q6_K: + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n"); + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_iq1_m_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_iq2_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; + default: + GGML_ABORT("fatal error"); } } GGML_UNUSED(src1); diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 4e9f438b4..4ec141684 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -1,40 +1,50 @@ #include "norm.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml-sycl/presets.hpp" -static void norm_f32(const float* x, float* dst, const int ncols, const float eps, - const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - const int tid = item_ct1.get_local_id(2); +static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) { + + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); const int nthreads = item_ct1.get_local_range(2); + const int sample = item_ct1.get_group(0); + const int channel = item_ct1.get_group(1); + const int row = item_ct1.get_group(2); + + const int tid = item_ct1.get_local_id(2); const int nwarps = nthreads / WARP_SIZE; + + const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); + const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + + x += strided_offset; + dst += packed_offset; + sycl::float2 mean_var = sycl::float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row * ncols + col]; + const float xi = x[col]; mean_var.x() += xi; mean_var.y() += xi * xi; } // sum up partial sums mean_var = warp_reduce_sum(mean_var, item_ct1); - if (block_size > WARP_SIZE) { - - int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = mean_var; + if (block_size > WARP_SIZE) { + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + if (wi_in_sg == 0) { + s_sum[sg_id] = mean_var; } - /* - DPCT1118:0: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ item_ct1.barrier(sycl::access::fence_space::local_space); mean_var = 0.f; - size_t nreduce = nwarps / WARP_SIZE; + const size_t nreduce = ceil_div(nwarps, WARP_SIZE); for (size_t i = 0; i < nreduce; i += 1) { - mean_var += s_sum[lane_id + i * WARP_SIZE]; + mean_var += s_sum[wi_in_sg + i * WARP_SIZE]; } mean_var = warp_reduce_sum(mean_var, item_ct1); } @@ -44,7 +54,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep const float inv_std = sycl::rsqrt(var + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row * ncols + col] = (x[row * ncols + col] - mean) * inv_std; + dst[col] = (x[col] - mean) * inv_std; } } @@ -135,39 +145,51 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con } } -static void rms_norm_f32(const float* x, float* dst, const int ncols, const float eps, - const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - const int tid = item_ct1.get_local_id(2); +static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { + + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); + + const int sample = item_ct1.get_group(0); + const int channel = item_ct1.get_group(1); + const int row = item_ct1.get_group(2); + const int nthreads = item_ct1.get_local_range(2); + + const int tid = item_ct1.get_local_id(2); const int nwarps = nthreads / WARP_SIZE; + + const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row}); + const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row}); + + x += strided_offset; + dst += packed_offset; + + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row * ncols + col]; + const float xi = x[col]; tmp += xi * xi; } // sum up partial sums tmp = warp_reduce_sum(tmp, item_ct1); if (block_size > WARP_SIZE) { - - int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; + const auto sub_group = item_ct1.get_sub_group(); + const auto sg_id = sub_group.get_group_linear_id(); + const auto wi_in_sg = sub_group.get_local_linear_id(); + if (wi_in_sg == 0) { + s_sum[sg_id] = tmp; } - /* - DPCT1118:3: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ + item_ct1.barrier(sycl::access::fence_space::local_space); - size_t nreduce = nwarps / WARP_SIZE; + const size_t nreduce = ceil_div(nwarps, WARP_SIZE); tmp = 0.f; for (size_t i = 0; i < nreduce; i += 1) { - tmp += s_sum[lane_id + i * WARP_SIZE]; + tmp += s_sum[wi_in_sg + i * WARP_SIZE]; } tmp = warp_reduce_sum(tmp, item_ct1); } @@ -176,7 +198,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa const float scale = sycl::rsqrt(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[row * ncols + col] = scale * x[row * ncols + col]; + dst[col] = scale * x[col]; } } @@ -224,20 +246,20 @@ static void l2_norm_f32(const float* x, float* dst, const int ncols, const float } } -static void norm_f32_sycl(const float* x, float* dst, const int ncols, - const int nrows, const float eps, - queue_ptr stream, int device) { +static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, + const float eps, queue_ptr stream, int device) { + + const sycl::range<3> global_dims(nsamples, nchannels, nrows); GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); }); }); } @@ -252,15 +274,12 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, */ stream->submit([&](sycl::handler& cgh) { sycl::local_accessor s_sum_acc_ct1( - sycl::range<1>(work_group_size / WARP_SIZE), cgh); - + sycl::range<1>(work_group_size / WARP_SIZE), cgh); cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - norm_f32(x, dst, ncols, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); + norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -313,21 +332,20 @@ static void group_norm_f32_sycl(const float* x, float* dst, } } -static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, - const int nrows, const float eps, - queue_ptr stream, int device) { +static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); + + const sycl::range<3> global_dims(nsamples, nchannels, nrows); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE); }); }); } @@ -344,12 +362,10 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), cgh); cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, - block_dims), + sycl::nd_range<3>(global_dims * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - rms_norm_f32(x, dst, ncols, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); + rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -398,12 +414,12 @@ static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, } void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); + GGML_TENSOR_UNARY_OP_LOCALS dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); const float * src0_dd = static_cast(dst->src[0]->data); @@ -411,8 +427,14 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; - norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); } void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { @@ -436,11 +458,10 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); @@ -450,7 +471,13 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float eps; memcpy(&eps, dst->op_params, sizeof(float)); - rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + GGML_TENSOR_UNARY_OP_LOCALS + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device); } void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index b60415784..3a17f3a1b 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -1,6 +1,7 @@ #include "outprod.hpp" void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); const ggml_tensor *src0 = dst->src[0]; const ggml_tensor *src1 = dst->src[1]; diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp new file mode 100644 index 000000000..8b952db43 --- /dev/null +++ b/ggml/src/ggml-sycl/quants.hpp @@ -0,0 +1,111 @@ +// +// MIT license +// Copyright (C) 2025 Codeplay Software Ltd. +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_QUANTS_HPP +#define GGML_SYCL_QUANTS_HPP + +#include + +#include "ggml-common.h" +#include "ggml.h" + +namespace ggml_sycl_reordered { + +// The reordered block moves quants (qs) and scales(d) to two +// uniform regions of memory that is contiguous in the same tensor. +// What this means is that instead of having: +// [d0, qs0] [d1, qs1] [d2, qs2] ... [dN, qsN] +// We have: +// [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] +// +// Notes: out-of-bounds qs will run into d values +// Aligment relies on the allocated size of qs + +template struct block_q_t; + +// qk number of weights / quants in a block +// qr number of weights in a byte (described as 'before dequantization') +// for quantization types that has low and high bits split, qr is calculated with +// using the lower bits, e.g for Q6 quants QR6 is 2 +// qi number of 32 bit integers needed to represent all the quants from a block (`qs` field) +// See ggml-common.h to see how these are calculated +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK4_0; + static constexpr uint32_t qi = QI4_0; + static constexpr uint32_t qr = QR4_0; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI4_K; + static constexpr uint32_t qr = QR4_K; + static constexpr uint32_t vdr_mmvq = 2; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * (traits::qk / traits::qr), 0 }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + return { nblocks * (QK_K / 2), + (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } + + constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } +}; + +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI6_K; + static constexpr uint32_t qr = QR6_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { + auto low_bits_index = block_index * (traits::qk / traits::qr); + // the index of high bits it's after all low bits + auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); + return { low_bits_index, high_bits_index }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / traits::qk)); + auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); + auto block_scales = total_qs_bytes + block_index * (QK_K / 16); + auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); + return { block_scales, sb_scale }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; +} // namespace ggml_sycl_reordered + +#endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index bbcb356e9..44473e1e5 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -1,9 +1,15 @@ #include "rope.hpp" +#include "ggml-sycl/common.hpp" +#include "ggml.h" struct rope_corr_dims { float v[2]; }; +struct mrope_sections { + int v[4]; +}; + static float rope_yarn_ramp(const float low, const float high, const int i0) { const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low); return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y)); @@ -28,105 +34,200 @@ static void rope_yarn( *sin_theta = sycl::sin(theta) * mscale; } -template -static void rope_norm( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, - const sycl::nd_item<3> &item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); +template +static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, float freq_scale, float ext_factor, float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, + const sycl::nd_item<3> & item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); if (i0 >= ne0) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i0 >= n_dims) { - const int i = row*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - + const int i = row * ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); return; } - const int i = row*ne0 + i0; - const int i2 = row/p_delta_rows; + const int row0 = row % ne1; + const int channel0 = row / ne1; - const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f); + const int i = row * ne0 + i0; + const int i2 = channel0 * s2 + row0 * s1 + i0; - const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + 1]; + const float x0 = x[i2 + 0]; + const float x1 = x[i2 + 1]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + 1] = x0*sin_theta + x1*cos_theta; + dst[i + 0] = x0 * cos_theta - x1 * sin_theta; + dst[i + 1] = x0 * sin_theta + x1 * cos_theta; } -template -static void rope_neox( - const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, - float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, - const sycl::nd_item<3> &item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + - item_ct1.get_local_id(1)); +template +static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, + const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, + const sycl::nd_item<3> & item_ct1) { + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); if (i0 >= ne0) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + - item_ct1.get_local_id(2); + const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); if (i0 >= n_dims) { - const int i = row*ne0 + i0; - - dst[i + 0] = x[i + 0]; - dst[i + 1] = x[i + 1]; - + const int i = row * ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); return; } - const int i = row*ne0 + i0/2; - const int i2 = row/p_delta_rows; + const int row0 = row % ne1; + const int channel0 = row / ne1; - const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f); + const int i = row * ne0 + i0 / 2; + const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; - const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float x0 = x[i + 0]; - const float x1 = x[i + n_dims/2]; + const float x0 = x[i2 + 0]; + const float x1 = x[i2 + n_dims / 2]; - dst[i + 0] = x0*cos_theta - x1*sin_theta; - dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta; + dst[i + 0] = x0 * cos_theta - x1 * sin_theta; + dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; +} + +template +static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections, + const sycl::nd_item<3> & item_ct1) { + // get index pos + const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); + if (i0 >= ne0) { + return; + } + const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); + + if (i0 >= n_dims) { + const int i = row_dst*ne0 + i0; + *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i); + return; + } + + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + const int idst = (row_dst * ne0) + (i0 / 2); + const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + + float theta_base = 0.0; + if (sector < sections.v[0]) { + theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); + } + else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + } + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; + float cos_theta; + float sin_theta; + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; + + // store results in dst + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta; +} + + + +template +static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float * freq_factors, const mrope_sections sections, + const sycl::nd_item<3> & item_ct1) { + // get index pos + const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); + if (i0 >= ne0) { + return; + } + const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); + const int row_x = row_dst % ne1; + const int channel_x = row_dst / ne1; + const int idst = (row_dst * ne0) + (i0 / 2); + const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int sect_dims = sections.v[0] + sections.v[1]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0f; + if (sector < sections.v[0]) { + const int p = sector; + theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p); + } else { + // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0] + const int p = sector - sections.v[0]; + theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p); + } + + const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; + float cos_theta; + float sin_theta; + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims]; + + // store results in dst + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta; } template -static void rope_norm_sycl( - const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { +static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, + const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base, + const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, + const float * freq_factors, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); const sycl::range<3> block_nums(1, num_blocks_x, nr); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); if (freq_factors == nullptr) { /* @@ -134,79 +235,136 @@ static void rope_norm_sycl( the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, - ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } else { /* DPCT1049:41: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. */ - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, - ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } } template -static void rope_neox_sycl( - const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows, - float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { +static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, + const int n_dims, const int nr, const int32_t * pos, const float freq_scale, + const float freq_base, const float ext_factor, const float attn_factor, + const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { GGML_ASSERT(ne0 % 2 == 0); const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE); + const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); const sycl::range<3> block_nums(1, num_blocks_x, nr); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), - {sycl::aspect::fp16}); + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); if (freq_factors == nullptr) { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } else { - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, n_dims, pos, freq_scale, - p_delta_rows, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, - item_ct1); - }); + stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { + rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, item_ct1); + }); } } -void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +template +static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int nr, const int32_t * pos, + const float freq_scale, const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, + const mrope_sections sections, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); + const sycl::range<3> grid_dims(1, n_blocks_y, nr); + const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); + + const float theta_scale = std::pow(freq_base, -2.0f / n_dims); + // Add FP16 capability check if T could be sycl::half + if constexpr (std::is_same_v) { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + } + // launch kernel + if (freq_factors == nullptr) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } else { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } +} + + + + +// rope vision +template +static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, + const size_t s2, const int n_dims, const int nr, const int32_t * pos, + const float freq_scale, const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, + const mrope_sections sections, queue_ptr stream) { + GGML_ASSERT(ne0 % 2 == 0); + const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); + const sycl::range<3> grid_dims(1, n_blocks_y, nr); + const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); + + const float theta_scale = std::pow(freq_base, -2.0f / n_dims); + // Add FP16 capability check if T could be sycl::half + if constexpr (std::is_same_v) { + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + } + // launch kernel + if (freq_factors == nullptr) { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } else { + stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { + rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, + corr_dims, theta_scale, freq_factors, sections, item_ct1); + }); + } +} + +inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); GGML_ASSERT(dst->src[0]->type == dst->type); - - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t ne01 = dst->src[0]->ne[1]; + const int64_t ne00 = dst->src[0]->ne[0]; // head dims + const int64_t ne01 = dst->src[0]->ne[1]; // num heads + const int64_t ne02 = dst->src[0]->ne[2]; // num heads const int64_t nr = ggml_nrows(dst->src[0]); + const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type); + const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type); + + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; //const int n_ctx = ((int32_t *) dst->op_params)[3]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + mrope_sections sections; // RoPE alteration for extended context float freq_base; @@ -222,8 +380,19 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } const int32_t * pos = (const int32_t *) dst->src[1]->data; @@ -240,32 +409,60 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { // compute if (is_neox) { + GGML_SYCL_DEBUG("%s: neox path\n", __func__); if (dst->src[0]->type == GGML_TYPE_F32) { - rope_neox_sycl( - (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, + pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_neox_sycl( - (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, + n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + main_stream); } else { GGML_ABORT("fatal error"); } + } else if (is_mrope && !is_vision) { + GGML_SYCL_DEBUG("%s: mrope path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F16) { + rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, + s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, + freq_factors, sections, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F32) { + rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, + nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, + main_stream); + } else { + GGML_ABORT("Fatal error: Tensor type unsupported!"); + } + } else if (is_vision) { + GGML_SYCL_DEBUG("%s: vision path\n", __func__); + if (dst->src[0]->type == GGML_TYPE_F16) { + rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01, + s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, + freq_factors, sections, main_stream); + } else if (dst->src[0]->type == GGML_TYPE_F32) { + rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, + nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, + main_stream); + } else { + GGML_ABORT("Fatal error: Tensor type unsupported!"); + } } else { + GGML_SYCL_DEBUG("%s: norm path\n", __func__); if (dst->src[0]->type == GGML_TYPE_F32) { - rope_norm_sycl( - (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, + pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_norm_sycl( - (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor, - attn_factor, corr_dims, freq_factors, main_stream - ); + rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, + n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + main_stream); } else { GGML_ABORT("fatal error"); } } } + +void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); + ggml_sycl_op_rope(ctx, dst); +} + diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp index a399bddb8..8c7141aac 100644 --- a/ggml/src/ggml-sycl/rope.hpp +++ b/ggml/src/ggml-sycl/rope.hpp @@ -15,6 +15,6 @@ #include "common.hpp" -void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); +void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); #endif // GGML_SYCL_ROPE_HPP diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 7563d9ced..52fcf4b3d 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -225,7 +225,7 @@ static void soft_max_f32_sycl(const float * x, const T * mask, } void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -249,16 +249,13 @@ void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F16) { const sycl::half * src1_dd = static_cast(dst->src[1]->data); - GGML_SYCL_DEBUG("%s: F16 mask\n", __func__); soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } else if (dst->src[1] && dst->src[1]->type == GGML_TYPE_F32) { const float * src1_dd = static_cast(dst->src[1]->data); - GGML_SYCL_DEBUG("%s: F32 mask\n", __func__); soft_max_f32_sycl(src0_dd, src1_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } else { /* mask unavailable */ - GGML_SYCL_DEBUG("%s: No mask\n", __func__); soft_max_f32_sycl(src0_dd, nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } } diff --git a/ggml/src/ggml-sycl/tsembd.cpp b/ggml/src/ggml-sycl/tsembd.cpp index b877d18c1..f6ca626ea 100644 --- a/ggml/src/ggml-sycl/tsembd.cpp +++ b/ggml/src/ggml-sycl/tsembd.cpp @@ -56,8 +56,8 @@ static void timestep_embedding_f32_sycl( } void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; dpct::queue_ptr stream = ctx.stream(); @@ -69,5 +69,4 @@ void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tenso const int max_period = dst->op_params[1]; timestep_embedding_f32_sycl(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream); - GGML_UNUSED(src1); } diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index c5942008a..0a5d49994 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: MIT // @@ -14,8 +14,11 @@ #define GGML_SYCL_VECDOTQ_HPP #include "dpct/helper.hpp" +#include "ggml.h" +#include "quants.hpp" -typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); +typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs); static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -252,13 +255,213 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +template struct reorder_vec_dot_q_sycl { + static_assert(T != T, "ggml_type for reorder vecdot not implemented"); +}; + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_0; + + using q4_0_block = ggml_sycl_reordered::block_q_t; + using q4_0_traits = typename q4_0_block::traits; + + __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = dpct::dp4a(vi0, u[2 * i + 0], sumi); + sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x() - (8 * q4_0_traits::vdr_mmvq / q4_0_traits::qi) * ds8f.y()); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * bq4_0 = static_cast(vbq) + ibx_offset.first; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); + int v[q4_0_traits::vdr_mmvq]; + int u[2 * q4_0_traits::vdr_mmvq]; + + +#pragma unroll + for (size_t i = 0; i < q4_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_uint8(bq4_0, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i + q4_0_traits::qi); + } + + return vec_dot_q4_0_q8_1_impl(v, u, d, *q8_1_ds); + }; +}; + +static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, + const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = bq8i->ds[0]; + + const int * q8 = (const int *) bq8i->qs + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, dm, d8); +} + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q4_K; + + using q4_k_block = ggml_sycl_reordered::block_q_t; + using q4_k_traits = typename q4_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const int ib = ibx_offset.first / (QK_K / 2); + + const uint8_t * base = static_cast(vbq); + const uint8_t * qs = base + ibx_offset.first; + const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE; + const ggml_half2 * dms = reinterpret_cast(base + d_offset.second); + + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; + + int v[2]; + int u[2 * QR4_K]; + float d8[QR4_K]; + + v[0] = q4[0]; + v[1] = q4[4]; + + uint16_t aux[2]; + const int j = (QR4_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); + + d8[i] = ds_values[0]; + + const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, *dms, d8); + } +}; + +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q6_K; + + using q6_k_block = ggml_sycl_reordered::block_q_t; + using q6_k_traits = typename q6_k_block::traits; + + __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, + const int8_t * __restrict__ scales, const float d, + const float * __restrict__ d8) { + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4 * i]; + + const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; + + const int vi = dpct::vectorized_binary((vil | vih), 0x20202020, + dpct::sub_sat()); // vi = (vil | vih) - 32 + + sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d * sumf; + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, + const int iqs) { + const int ib = ibx_offset.first / (QK_K / 2); + + const uint8_t * base = static_cast(vbq); + const uint8_t * ql = base + ibx_offset.first; + const uint8_t * qh = base + ibx_offset.second; + const int8_t * scales = reinterpret_cast(base + d_offset.first); + const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4); + const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8); + const int vh_shift = 2 * ((iqs % (QI6_K / 2)) / (QI6_K / 4)); + + const int vl = get_int_from_uint8(ql, iqs); + const int vh = get_int_from_uint8(qh, (QI6_K / 4) * (iqs / (QI6_K / 2)) + iqs % (QI6_K / 4)) >> vh_shift; + + const int8_t * scs = scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); + const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); + d8[i] = ds_values[0]; + } + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + } +}; #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 template -static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, - const float &d4, - const sycl::half2 &ds8) { +static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int * v, const int * u, const float & d4, + const sycl::half2 & ds8) { int sumi = 0; #pragma unroll for (int i = 0; i < vdr; ++i) { @@ -270,8 +473,7 @@ static __dpct_inline__ float vec_dot_q4_0_q8_1_impl(const int *v, const int *u, sumi = dpct::dp4a(vi1, u[2 * i + 1], sumi); } - const sycl::float2 ds8f = - ds8.convert(); + const sycl::float2 ds8f = ds8.convert(); // second part effectively subtracts 8 from each quant value return d4 * (sumi * ds8f.x() - (8 * vdr / QI4_0) * ds8f.y()); @@ -456,13 +658,13 @@ vec_dot_q4_0_q8_1(const void *__restrict__ vbq, const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; int v[VDR_Q4_0_Q8_1_MMVQ]; - int u[2*VDR_Q4_0_Q8_1_MMVQ]; + int u[2 * VDR_Q4_0_Q8_1_MMVQ]; #pragma unroll for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { - v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); - u[2*i+0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); - u[2*i+1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); + v[i] = get_int_from_uint8(bq4_0->qs, iqs + i); + u[2 * i + 0] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); + u[2 * i + 1] = get_int_from_int8_aligned(bq8_1->qs, iqs + i + QI4_0); } return vec_dot_q4_0_q8_1_impl(v, u, bq4_0->d, bq8_1->ds); @@ -600,52 +802,17 @@ vec_dot_q3_K_q8_1(const void *__restrict__ vbq, return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); } -static __dpct_inline__ float -vec_dot_q4_K_q8_1(const void *__restrict__ vbq, - const block_q8_1 *__restrict__ bq8_1, const int &iqs) { - +static __dpct_inline__ float vec_dot_q4_K_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, + const int & iqs) { #ifndef GGML_QKK_64 + const block_q4_K * bq4_K = (const block_q4_K *) vbq; - int v[2]; - int u[2*QR4_K]; - float d8[QR4_K]; + const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); + const int * q4 = (const int *) (bq4_K->qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) bq4_K->scales; - // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 - const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); - - // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 - // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 - // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 - // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 - - const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); - v[0] = q4[0]; - v[1] = q4[4]; - - const uint16_t * scales = (const uint16_t *)bq4_K->scales; - uint16_t aux[2]; - const int j = bq8_offset/2; - if (j < 2) { - aux[0] = scales[j+0] & 0x3f3f; - aux[1] = scales[j+2] & 0x3f3f; - } else { - aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); - aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); - } - const uint8_t * sc = (const uint8_t *)aux; - const uint8_t * m = sc + 2; - - for (int i = 0; i < QR4_K; ++i) { - const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; - d8[i] = bq8i->ds[0]; - - const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); - u[2*i+0] = q8[0]; - u[2*i+1] = q8[4]; - } - - return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); + return vec_dot_q4_K_q8_1_common(q4, scales, bq4_K->dm, bq8_1, iqs); #else diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index 540f6fbf5..c10e2f764 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -180,10 +180,7 @@ static void rwkv_wkv7_f32_kernel( } void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6); const float* k_d = (const float*)dst->src[0]->data; const float* v_d = (const float*)dst->src[1]->data; const float* r_d = (const float*)dst->src[2]->data; @@ -236,16 +233,10 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { }); }); } - - GGML_UNUSED(src0); - GGML_UNUSED(src1); } void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { - - const ggml_tensor *src0 = dst->src[0]; - const ggml_tensor *src1 = dst->src[1]; - + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); const float* r_d = (const float*)dst->src[0]->data; const float* w_d = (const float*)dst->src[1]->data; const float* k_d = (const float*)dst->src[2]->data; @@ -299,7 +290,4 @@ void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { }); }); } - - GGML_UNUSED(src0); - GGML_UNUSED(src1); } diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 9d028f718..39f022f33 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -15,6 +15,32 @@ function(detect_host_compiler) set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) endfunction() +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + if (Vulkan_FOUND) message(STATUS "Vulkan found") @@ -23,53 +49,32 @@ if (Vulkan_FOUND) ../../include/ggml-vulkan.h ) - # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "") - if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") - message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_KHR_cooperative_matrix supported by glslc") - set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - endif() + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) - # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) - if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") - message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") - set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) - # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported. - # If it's not, there will be an error to stderr. - # If it's supported, set a define to indicate that we should compile those shaders - execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" - OUTPUT_VARIABLE glslc_output - ERROR_VARIABLE glslc_error) - - if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") - message(STATUS "GL_EXT_integer_dot_product not supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) - else() - message(STATUS "GL_EXT_integer_dot_product supported by glslc") - set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) - add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - endif() + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -96,10 +101,6 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) endif() - if (GGML_VULKAN_PERF) - add_compile_definitions(GGML_VULKAN_PERF) - endif() - if (GGML_VULKAN_VALIDATE) add_compile_definitions(GGML_VULKAN_VALIDATE) endif() @@ -108,16 +109,8 @@ if (Vulkan_FOUND) add_compile_definitions(GGML_VULKAN_RUN_TESTS) endif() - if (NOT CMAKE_CROSSCOMPILING) - add_subdirectory(vulkan-shaders) - if (MSVC) - foreach(CONFIG ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${CONFIG} CONFIG) - set_target_properties(vulkan-shaders-gen PROPERTIES - RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) - endforeach() - endif() - else() + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) else() @@ -130,41 +123,50 @@ if (Vulkan_FOUND) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) endif() - message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") - - include(ExternalProject) - # Native build through ExternalProject_Add - ExternalProject_Add( - vulkan-shaders-gen - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders - CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} - -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} - -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} - -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} - -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} - BUILD_COMMAND ${CMAKE_COMMAND} --build . - INSTALL_COMMAND ${CMAKE_COMMAND} --install . - INSTALL_DIR ${CMAKE_BINARY_DIR} - ) - ExternalProject_Add_StepTargets(vulkan-shaders-gen build install) + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") endif() - set (_ggml_vk_host_suffix $,.exe,>) - set (_ggml_vk_genshaders_cmd ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/vulkan-shaders-gen${_ggml_vk_host_suffix}) - set (_ggml_vk_header ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp) - set (_ggml_vk_source ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp) - set (_ggml_vk_input_dir ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders) - set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv) - file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp") - set (_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen) + include(ExternalProject) if (CMAKE_CROSSCOMPILING) - set(_ggml_vk_shader_deps ${_ggml_vk_shader_deps} vulkan-shaders-gen-build vulkan-shaders-gen-install) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") endif() + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$ + -DCMAKE_INSTALL_BINDIR=. + -DCMAKE_BUILD_TYPE=$ + ${VULKAN_SHADER_GEN_CMAKE_ARGS} + + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $ + + # NOTE: When DESTDIR is set using Makefile generators and + # "make install" triggers the build step, vulkan-shaders-gen + # would be installed into the DESTDIR prefix, so it is unset + # to ensure that does not happen. + + INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR + ${CMAKE_COMMAND} --install . --config $ + ) + + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") + set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") + set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") + set (_ggml_vk_source "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp") + set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") + set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") + + file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp") + add_custom_command( OUTPUT ${_ggml_vk_header} - ${_ggml_vk_source} + ${_ggml_vk_source} COMMAND ${_ggml_vk_genshaders_cmd} --glslc ${Vulkan_GLSLC_EXECUTABLE} @@ -174,7 +176,9 @@ if (Vulkan_FOUND) --target-cpp ${_ggml_vk_source} --no-clean - DEPENDS ${_ggml_vk_shader_deps} + DEPENDS ${_ggml_vk_shader_files} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders" ) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 783a0ff86..8d62303aa 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1,6 +1,6 @@ #include "ggml-vulkan.h" #include -#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS) +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) #include #include "ggml-cpu.h" #endif @@ -51,6 +51,24 @@ #include "ggml-vulkan-shaders.hpp" +// remove this once it's more widely available in the SDK +#if !defined(VK_KHR_shader_bfloat16) + +#define VK_KHR_shader_bfloat16 1 +#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1 +#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000) +#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) + +typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 shaderBFloat16Type; + VkBool32 shaderBFloat16DotProduct; + VkBool32 shaderBFloat16CooperativeMatrix; +} VkPhysicalDeviceShaderBfloat16FeaturesKHR; +#endif + #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } @@ -60,7 +78,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de -#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 32 +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 #define GGML_VK_MAX_NODES 8192 @@ -84,25 +102,11 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } struct ggml_backend_vk_context; -struct vk_queue { - uint32_t queue_family_index; - vk::Queue queue; - vk::CommandPool pool; - uint32_t cmd_buffer_idx; - std::vector cmd_buffers; - - vk::PipelineStageFlags stage_flags; - - bool transfer_only; -}; +#define MAX_PARAMETER_COUNT 8 struct vk_pipeline_struct { std::string name; vk::ShaderModule shader_module; - vk::DescriptorSetLayout dsl; - std::vector descriptor_pools; - std::vector descriptor_sets; - uint32_t descriptor_set_idx; vk::PipelineLayout layout; vk::Pipeline pipeline; uint32_t push_constant_size; @@ -149,6 +153,45 @@ struct ggml_backend_vk_buffer_type_context { vk_device device; }; +struct vk_queue; + +// Stores command pool/buffers. There's an instance of this +// for each (context,queue) pair and for each (device,queue) pair. +struct vk_command_pool { + void init(vk_device& device, vk_queue *q_); + void destroy(vk::Device& device); + + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk_queue *q; +}; + +// Prevent simultaneous submissions to the same queue. +// This could be per vk_queue if we stopped having two vk_queue structures +// sharing the same vk::Queue. +static std::mutex queue_mutex; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + + vk_command_pool cmd_pool; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; + + // copy everything except the cmd_pool + void copyFrom(vk_queue &other) { + queue_family_index = other.queue_family_index; + queue = other.queue; + stage_flags = other.stage_flags; + transfer_only = other.transfer_only; + } +}; + static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); @@ -166,9 +209,7 @@ static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { #ifdef GGML_VULKAN_MEMORY_DEBUG class vk_memory_logger; #endif -#ifdef GGML_VULKAN_PERF class vk_perf_logger; -#endif static void ggml_vk_destroy_buffer(vk_buffer& buf); static constexpr uint32_t mul_mat_vec_max_cols = 8; @@ -180,6 +221,7 @@ enum vk_device_architecture { AMD_RDNA1, AMD_RDNA2, AMD_RDNA3, + INTEL_XE2, }; static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { @@ -230,6 +272,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& } return vk_device_architecture::AMD_RDNA2; } + } else if (props.vendorID == VK_VENDOR_ID_INTEL) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &subgroup_size_control_props; + device.getProperties2(&props2); + + if (subgroup_size_control_props.minSubgroupSize == 16) { + // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8. + // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value. + // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html + // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html + return vk_device_architecture::INTEL_XE2; + } } return vk_device_architecture::OTHER; } @@ -246,6 +316,7 @@ struct vk_device_struct { bool pipeline_robustness; vk::Device device; uint32_t vendor_id; + vk::DriverId driver_id; vk_device_architecture architecture; vk_queue compute_queue; vk_queue transfer_queue; @@ -256,6 +327,7 @@ struct vk_device_struct { bool prefer_host_memory; bool float_controls_rte_fp16; bool subgroup_add; + bool subgroup_shuffle; bool integer_dot_product; @@ -265,8 +337,12 @@ struct vk_device_struct { bool subgroup_require_full_support; bool coopmat_support; - bool coopmat_acc_f32_support; - bool coopmat_acc_f16_support; + bool coopmat_acc_f32_support {}; + bool coopmat_acc_f16_support {}; + bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; @@ -290,8 +366,11 @@ struct vk_device_struct { // set to true to indicate that some shaders need to be compiled after the dryrun bool need_compiles {}; + vk::DescriptorSetLayout dsl; + vk_matmul_pipeline pipeline_matmul_f32 {}; vk_matmul_pipeline pipeline_matmul_f32_f16 {}; + vk_matmul_pipeline pipeline_matmul_bf16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; @@ -300,6 +379,7 @@ struct vk_device_struct { vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; vk_matmul_pipeline pipeline_matmul_id_f32 {}; + vk_matmul_pipeline pipeline_matmul_id_bf16 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; @@ -318,11 +398,17 @@ struct vk_device_struct { vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; - vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; - vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; - vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat; - vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; - vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; + + // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] + vk_pipeline pipeline_add[2][2][2]; + vk_pipeline pipeline_add_norepeat[2][2][2]; + vk_pipeline pipeline_sub[2][2][2]; + vk_pipeline pipeline_sub_norepeat[2][2][2]; + vk_pipeline pipeline_mul[2][2][2]; + vk_pipeline pipeline_mul_norepeat[2][2][2]; + vk_pipeline pipeline_div[2][2][2]; + vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; @@ -332,8 +418,8 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; @@ -341,14 +427,17 @@ struct vk_device_struct { vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_back_f32; vk_pipeline pipeline_l2_norm_f32; - vk_pipeline pipeline_gelu_f32; - vk_pipeline pipeline_gelu_quick_f32; - vk_pipeline pipeline_silu_f32; - vk_pipeline pipeline_silu_back_f32; - vk_pipeline pipeline_relu_f32; + + // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_quick[2]; + vk_pipeline pipeline_silu[2]; + vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_tanh[2]; + vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_leaky_relu_f32; - vk_pipeline pipeline_tanh_f32; - vk_pipeline pipeline_sigmoid_f32; + vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; @@ -363,22 +452,39 @@ struct vk_device_struct { vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_dw_whcn_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_split_k_reduce; std::unordered_map pipelines; - std::unordered_map pipeline_descriptor_set_requirements; std::vector> pinned_memory; @@ -390,9 +496,11 @@ struct vk_device_struct { #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; #endif -#ifdef GGML_VULKAN_PERF + + // for GGML_VK_PERF_LOGGER std::unique_ptr perf_logger; -#endif + vk::QueryPool query_pool; + int32_t num_queries; ~vk_device_struct() { VK_LOG_DEBUG("destroy device " << name); @@ -401,10 +509,8 @@ struct vk_device_struct { ggml_vk_destroy_buffer(sync_staging); - device.destroyCommandPool(compute_queue.pool); - if (!single_queue) { - device.destroyCommandPool(transfer_queue.pool); - } + compute_queue.cmd_pool.destroy(device); + transfer_queue.cmd_pool.destroy(device); for (auto& pipeline : pipelines) { if (pipeline.second.expired()) { @@ -416,10 +522,26 @@ struct vk_device_struct { } pipelines.clear(); + device.destroyDescriptorSetLayout(dsl); + device.destroy(); } }; +void vk_command_pool::init(vk_device& device, vk_queue *q_) { + cmd_buffer_idx = 0; + q = q_; + + vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + pool = device->device.createCommandPool(command_pool_create_info); +} + +void vk_command_pool::destroy(vk::Device& device) { + device.destroyCommandPool(pool); + pool = nullptr; + cmd_buffers.clear(); +} + struct vk_buffer_struct { vk::Buffer buffer = VK_NULL_HANDLE; vk::DeviceMemory device_memory = VK_NULL_HANDLE; @@ -654,6 +776,21 @@ struct vk_op_timestep_embedding_push_constants { uint32_t max_period; }; +struct vk_op_conv_transpose_1d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -679,6 +816,24 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t H; }; +struct vk_op_conv2d_dw_push_constants { + uint32_t ne; + uint32_t batches; + uint32_t channels; + uint32_t dst_w; + uint32_t dst_h; + uint32_t src_w; + uint32_t src_h; + uint32_t knl_w; + uint32_t knl_h; + int32_t stride_x; + int32_t stride_y; + int32_t pad_x; + int32_t pad_y; + int32_t dilation_x; + int32_t dilation_y; +}; + struct vk_op_upscale_push_constants { uint32_t ne; uint32_t a_offset; uint32_t d_offset; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; @@ -704,7 +859,7 @@ struct vk_context_struct { std::vector in_memcpys; std::vector out_memcpys; - vk_queue * q; + vk_command_pool * p {}; }; typedef std::shared_ptr vk_context; typedef std::weak_ptr vk_context_ref; @@ -758,8 +913,6 @@ private: #define VK_LOG_MEMORY(msg) ((void) 0) #endif // GGML_VULKAN_MEMORY_DEBUG -#if defined(GGML_VULKAN_PERF) - class vk_perf_logger { public: void print_timings() { @@ -769,7 +922,7 @@ public: for (const auto& time : t.second) { total += time; } - std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " ms" << std::endl; + std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl; } timings.clear(); @@ -798,7 +951,6 @@ public: private: std::map> timings; }; -#endif // GGML_VULKAN_PERF struct ggml_backend_vk_context { std::string name; @@ -818,6 +970,14 @@ struct ggml_backend_vk_context { vk_context_ref transfer_ctx; std::vector tensor_ctxs; + + std::vector descriptor_pools; + std::vector descriptor_sets; + uint32_t descriptor_set_idx {}; + uint32_t pipeline_descriptor_set_requirements {}; + + vk_command_pool compute_cmd_pool; + vk_command_pool transfer_cmd_pool; }; static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT @@ -888,6 +1048,8 @@ struct vk_instance_t { static bool vk_instance_initialized = false; static vk_instance_t vk_instance; +static bool vk_perf_logger_enabled = false; + #ifdef GGML_VULKAN_CHECK_RESULTS static size_t vk_skip_checks; static size_t vk_output_tensor; @@ -946,39 +1108,19 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT); GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); - std::vector dsl_binding; - std::vector dsl_binding_flags; - for (uint32_t i = 0; i < parameter_count; i++) { - dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); - dsl_binding_flags.push_back({}); - } - - vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; - vk::PushConstantRange pcr( vk::ShaderStageFlagBits::eCompute, 0, pipeline->push_constant_size ); - vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( - {}, - dsl_binding); - descriptor_set_layout_create_info.setPNext(&dslbfci); - pipeline->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); - - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - - pipeline->descriptor_set_idx = 0; - - vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), pipeline->dsl, pcr); + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr); pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); std::vector specialization_entries(specialization_constants.size()); @@ -1053,15 +1195,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); - for (auto& pool : pipeline->descriptor_pools) { - device.destroyDescriptorPool(pool); - } - pipeline->descriptor_pools.clear(); - pipeline->descriptor_sets.clear(); - pipeline->descriptor_set_idx = 0; - - device.destroyDescriptorSetLayout(pipeline->dsl); - device.destroyPipelineLayout(pipeline->layout); device.destroyShaderModule(pipeline->shader_module); @@ -1069,97 +1202,77 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) device.destroyPipeline(pipeline->pipeline); } -static void ggml_pipeline_request_descriptor_sets(vk_device& device, vk_pipeline& pipeline, uint32_t n) { +static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) { VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); - device->pipeline_descriptor_set_requirements[pipeline->name] += n; + ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { pipeline->needed = true; - device->need_compiles = true; + ctx->device->need_compiles = true; } } -static void ggml_pipeline_allocate_descriptor_sets(vk_device& device) { - std::lock_guard guard(device->mutex); +static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) { - for (auto& pair : device->pipeline_descriptor_set_requirements) { - vk_pipeline pipeline = device->pipelines.at(pair.first).lock(); - const uint64_t n = pair.second; + if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) { + // Enough descriptors are available + return; + } - VK_LOG_DEBUG("ggml_pipeline_allocate_descriptor_sets(" << pipeline->name << ", " << n << ")"); + vk_device& device = ctx->device; - if (pipeline->descriptor_sets.size() >= pipeline->descriptor_set_idx + n) { - // Enough descriptors are available - continue; + uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= ctx->descriptor_pools.size()) { + vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE); + vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); + ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); } - uint32_t to_alloc = pipeline->descriptor_set_idx + n - pipeline->descriptor_sets.size(); - uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - pipeline->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; - uint32_t pool_idx = pipeline->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; - - while (to_alloc > 0) { - const uint32_t alloc_count = std::min(pool_remaining, to_alloc); - to_alloc -= alloc_count; - pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; - - if (pool_idx >= pipeline->descriptor_pools.size()) { - vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, pipeline->parameter_count * VK_DEVICE_DESCRIPTOR_POOL_SIZE); - vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size); - pipeline->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info)); - } - - std::vector layouts(alloc_count); - for (uint32_t i = 0; i < alloc_count; i++) { - layouts[i] = pipeline->dsl; - } - vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(pipeline->descriptor_pools[pool_idx], alloc_count, layouts.data()); - std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); - pipeline->descriptor_sets.insert(pipeline->descriptor_sets.end(), sets.begin(), sets.end()); - - pool_idx++; + std::vector layouts(alloc_count); + for (uint32_t i = 0; i < alloc_count; i++) { + layouts[i] = device->dsl; } + vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data()); + std::vector sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info); + ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end()); + + pool_idx++; } } -static void ggml_pipeline_cleanup(vk_pipeline& pipeline) { - VK_LOG_DEBUG("ggml_pipeline_cleanup(" << pipeline->name << ")"); - pipeline->descriptor_set_idx = 0; -} - -static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_queue& q) { +static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); - std::lock_guard guard(device->mutex); - if (q.cmd_buffers.size() > q.cmd_buffer_idx) { + if (p.cmd_buffers.size() > p.cmd_buffer_idx) { // Reuse command buffer - return q.cmd_buffers[q.cmd_buffer_idx++]; + return p.cmd_buffers[p.cmd_buffer_idx++]; } vk::CommandBufferAllocateInfo command_buffer_alloc_info( - q.pool, + p.pool, vk::CommandBufferLevel::ePrimary, 1); const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); auto buf = cmd_buffers.front(); - q.cmd_buffers.push_back(buf); - q.cmd_buffer_idx++; + p.cmd_buffers.push_back(buf); + p.cmd_buffer_idx++; return buf; } -static vk_submission ggml_vk_create_submission(vk_device& device, vk_queue& q, std::vector wait_semaphores, std::vector signal_semaphores) { - VK_LOG_DEBUG("ggml_vk_create_submission()"); - vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, q); - s.wait_semaphores = std::move(wait_semaphores); - s.signal_semaphores = std::move(signal_semaphores); - return s; -} - static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { if (ctx->seqs.empty()) { if (fence) { - ctx->q->queue.submit({}, fence); + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit({}, fence); } return; } @@ -1198,7 +1311,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { tl_signal_vals.push_back({}); tl_signal_semaphores.push_back({}); for (size_t i = 0; i < submission.wait_semaphores.size(); i++) { - stage_flags[idx].push_back(ctx->q->stage_flags); + stage_flags[idx].push_back(ctx->p->q->stage_flags); tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value); tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s); } @@ -1228,7 +1341,8 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { } } - ctx->q->queue.submit(submit_infos, fence); + std::lock_guard guard(queue_mutex); + ctx->p->q->queue.submit(submit_infos, fence); ctx->seqs.clear(); } @@ -1286,28 +1400,25 @@ static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_ q.queue_family_index = queue_family_index; q.transfer_only = transfer_only; - vk::CommandPoolCreateInfo command_pool_create_info_compute(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), queue_family_index); - q.pool = device->device.createCommandPool(command_pool_create_info_compute); - - q.cmd_buffer_idx = 0; + q.cmd_pool.init(device, &q); q.queue = device->device.getQueue(queue_family_index, queue_index); q.stage_flags = stage_flags; } -static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_queue& q) { +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) { vk_context result = std::make_shared(); VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); ctx->gc.contexts.emplace_back(result); - result->q = &q; + result->p = &p; return result; } -static vk_context ggml_vk_create_temporary_context(vk_queue& q) { +static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) { vk_context result = std::make_shared(); VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); - result->q = &q; + result->p = &p; return result; } @@ -1340,15 +1451,29 @@ static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { return ctx->gc.events[ctx->event_idx++]; } -static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) { - VK_LOG_DEBUG("ggml_vk_queue_cleanup()"); - std::lock_guard guard(device->mutex); +static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()"); // Requires command buffers to be done - device->device.resetCommandPool(q.pool); - q.cmd_buffer_idx = 0; + device->device.resetCommandPool(p.pool); + p.cmd_buffer_idx = 0; } +static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()"); + + // Arbitrary frequency to cleanup/reuse command buffers + static constexpr uint32_t cleanup_frequency = 10; + + if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); + } + if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); + } +} + + static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { vk::MemoryType memory_type = mem_props->memoryTypes[i]; @@ -1367,8 +1492,6 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit"); } - std::lock_guard guard(device->mutex); - vk_buffer buf = std::make_shared(); if (size == 0) { @@ -1497,11 +1620,11 @@ static vk_subbuffer ggml_vk_subbuffer(vk_buffer& buf) { static void ggml_vk_sync_buffers(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_sync_buffers()"); - const bool transfer_queue = ctx->q->transfer_only; + const bool transfer_queue = ctx->p->q->transfer_only; ctx->s->buffer.pipelineBarrier( - ctx->q->stage_flags, - ctx->q->stage_flags, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, {}, { { { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, @@ -1520,29 +1643,70 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ctx->s->buffer.waitEvents( events, - ctx->q->stage_flags, - ctx->q->stage_flags, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, {}, {}, {} ); } +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. +static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } +} + +static std::array fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); + if (path == FA_SCALAR) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + return {scalar_flash_attention_num_large_rows, 32}; + } + } + + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + } + } + // small rows, large cols if (small_rows) { - return {flash_attention_num_small_rows, 64}; + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; } + // small cols to reduce register count if (ggml_is_quantized(type) || D == 256) { return {64, 32}; } return {64, 64}; -}; +} static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -1581,7 +1745,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; @@ -1740,6 +1904,11 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; @@ -1785,6 +1954,12 @@ static void ggml_vk_load_shaders(vk_device& device) { if (!device->pipeline_matmul_id_f32) { device->pipeline_matmul_id_f32 = std::make_shared(); } + if (!device->pipeline_matmul_bf16) { + device->pipeline_matmul_bf16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_bf16) { + device->pipeline_matmul_id_bf16 = std::make_shared(); + } std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, @@ -1820,65 +1995,75 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. + const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + }; + +#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - - auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; - auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); - // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads - GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; - }; - -#define CREATE_FA2(TYPE, NAMELC, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - -#define CREATE_FA(TYPE, NAMELC) \ - CREATE_FA2(TYPE, NAMELC, 64) \ - CREATE_FA2(TYPE, NAMELC, 80) \ - CREATE_FA2(TYPE, NAMELC, 96) \ - CREATE_FA2(TYPE, NAMELC, 112) \ - CREATE_FA2(TYPE, NAMELC, 128) \ - CREATE_FA2(TYPE, NAMELC, 256) - - CREATE_FA(GGML_TYPE_F16, f16) - CREATE_FA(GGML_TYPE_Q4_0, q4_0) - CREATE_FA(GGML_TYPE_Q4_1, q4_1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0) - CREATE_FA(GGML_TYPE_Q5_1, q5_1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0) - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //CREATE_FA(GGML_TYPE_Q2_K, q2_k) - //CREATE_FA(GGML_TYPE_Q3_K, q3_k) - //CREATE_FA(GGML_TYPE_Q4_K, q4_k) - //CREATE_FA(GGML_TYPE_Q5_K, q5_k) - //CREATE_FA(GGML_TYPE_Q6_K, q6_k) - //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s) - //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m) - //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) - //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) - //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) - //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) - //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) - //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + } +#endif +#undef CREATE_FA2 #undef CREATE_FA +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ @@ -1894,27 +2079,37 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) - CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + } +#endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + } +#endif CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) @@ -1943,17 +2138,17 @@ static void ggml_vk_load_shaders(vk_device& device) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -1968,54 +2163,64 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ) + } +#endif if (device->coopmat_acc_f16_support) { - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + } +#endif if (device->coopmat_acc_f16_support) { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -2080,13 +2285,19 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ -#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _m[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _s[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + } \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -2098,34 +2309,36 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f16acc, matmul_iq1_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f16acc, matmul_iq1_m_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f16acc, matmul_iq2_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f16acc, matmul_iq2_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); - CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); } #endif @@ -2133,6 +2346,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -2172,19 +2387,21 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ -#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -2208,11 +2425,11 @@ static void ggml_vk_load_shaders(vk_device& device) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { - CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); - CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); } #endif @@ -2220,6 +2437,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -2240,8 +2459,26 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -#undef CREATE_MM } + // reusing CREATE_MM from the fp32 path + if ((device->coopmat2 || device->coopmat_support) +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + && !device->coopmat_bf16_support +#endif + ) { + // use scalar tile sizes + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + + l_wg_denoms = {128, 128, 1 }; + m_wg_denoms = { 64, 64, 1 }; + s_wg_denoms = { 32, 32, 1 }; + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM // mul mat vec @@ -2260,6 +2497,7 @@ static void ggml_vk_load_shaders(vk_device& device) { for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); @@ -2282,6 +2520,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); @@ -2305,6 +2544,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); @@ -2350,6 +2590,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2367,6 +2608,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2393,21 +2635,26 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); } } - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + if (device->float_controls_rte_fp16) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); @@ -2431,20 +2678,32 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + +#define CREATE_BINARY(name, namemod, spec) \ + for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}) + CREATE_BINARY(add, _norepeat, {1}) + CREATE_BINARY(sub, , {0}) + CREATE_BINARY(sub, _norepeat, {1}) + CREATE_BINARY(mul, , {0}) + CREATE_BINARY(mul, _norepeat, {1}) + CREATE_BINARY(div, , {0}) + CREATE_BINARY(div, _norepeat, {1}) +#undef CREATE_BINARY ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -2464,14 +2723,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); +#define CREATE_UNARY(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + CREATE_UNARY(gelu) + CREATE_UNARY(gelu_quick) + CREATE_UNARY(silu) + CREATE_UNARY(relu) + CREATE_UNARY(tanh) + CREATE_UNARY(sigmoid) +#undef CREATE_UNARY + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); @@ -2515,6 +2780,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -2523,6 +2790,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + for (auto &c : compiles) { c.wait(); } @@ -2542,9 +2812,9 @@ static vk_device ggml_vk_get_device(size_t idx) { #ifdef GGML_VULKAN_MEMORY_DEBUG device->memory_logger = std::unique_ptr(new vk_memory_logger()); #endif -#ifdef GGML_VULKAN_PERF - device->perf_logger = std::unique_ptr(new vk_perf_logger()); -#endif + if (vk_perf_logger_enabled) { + device->perf_logger = std::unique_ptr(new vk_perf_logger()); + } size_t dev_num = vk_instance.device_indices[idx]; @@ -2572,6 +2842,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool coopmat2_support = false; device->coopmat_support = false; device->integer_dot_product = false; + bool bfloat16_support = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -2588,19 +2859,28 @@ static vk_device ggml_vk_get_device(size_t idx) { pipeline_robustness = true; } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { device->subgroup_size_control = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT")) { device->coopmat_support = true; device->coopmat_m = 0; device->coopmat_n = 0; device->coopmat_k = 0; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#endif #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { device->integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; #endif } } @@ -2658,6 +2938,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; + device->driver_id = driver_props.driverID; const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); @@ -2693,6 +2974,9 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -2787,6 +3071,17 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.pNext = nullptr; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif + VkPhysicalDeviceMaintenance4Features maint4_features {}; maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; if (maintenance4_support) { @@ -2825,6 +3120,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; #endif if (coopmat2_support) { @@ -2959,6 +3259,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f32_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { // coopmat sizes not set yet @@ -2971,6 +3274,9 @@ static vk_device ggml_vk_get_device(size_t idx) { // Only enable if shape is identical device->coopmat_acc_f16_support = true; } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } } } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && @@ -2984,6 +3290,25 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_int_n = prop.NSize; device->coopmat_int_k = prop.KSize; } +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_bf16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_bf16_support = true; + } + } +#endif } if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { @@ -2991,11 +3316,19 @@ static vk_device ggml_vk_get_device(size_t idx) { GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); device->coopmat_support = false; } + if (getenv("GGML_VK_DISABLE_BFLOAT16")) { + device->coopmat_bf16_support = false; + } } if (device->coopmat_support) { device_extensions.push_back("VK_KHR_cooperative_matrix"); } +#if defined(VK_KHR_shader_bfloat16) + if (device->coopmat_bf16_support) { + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif #endif device->name = GGML_VK_NAME + std::to_string(idx); @@ -3045,6 +3378,22 @@ static vk_device ggml_vk_get_device(size_t idx) { } } + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + ggml_vk_load_shaders(device); if (!device->single_queue) { @@ -3052,7 +3401,8 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); } else { // TODO: Use pointer or reference to avoid copy - device->transfer_queue = device->compute_queue; + device->transfer_queue.copyFrom(device->compute_queue); + device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); } device->buffer_type = { @@ -3269,11 +3619,13 @@ static void ggml_vk_instance_init() { vk_instance.instance = vk::createInstance(instance_create_info); vk_instance_initialized = true; - size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); if (devices_env != nullptr) { + size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size(); + std::string devices(devices_env); std::replace(devices.begin(), devices.end(), ',', ' '); @@ -3289,9 +3641,9 @@ static void ggml_vk_instance_init() { } else { std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); - // Make sure at least one device exists + // If no vulkan devices are found, return early if (devices.empty()) { - std::cerr << "ggml_vulkan: Error: No devices found." << std::endl; + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); return; } @@ -3374,9 +3726,20 @@ static void ggml_vk_instance_init() { } } - // If no dedicated GPUs found, fall back to GPU 0 + // If no dedicated GPUs found, fall back to the first non-CPU device. + // If only CPU devices are available, return without devices. if (vk_instance.device_indices.empty()) { - vk_instance.device_indices.push_back(0); + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) { + vk_instance.device_indices.push_back(i); + break; + } + } + } + + if (vk_instance.device_indices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; } } GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); @@ -3405,6 +3768,9 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->fence = ctx->device->device.createFence({}); ctx->almost_ready_fence = ctx->device->device.createFence({}); + ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); @@ -3445,13 +3811,16 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type } static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { - VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")"); if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_f32; } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { return ctx->device->pipeline_matmul_f32_f16; } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_bf16; + } if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_f16_f32.f16acc; @@ -3470,7 +3839,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte // MMQ if (src1_type == GGML_TYPE_Q8_1) { - vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc; + vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { return nullptr; @@ -3510,9 +3879,12 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte if (ctx->device->coopmat2) { assert(src1_type == GGML_TYPE_F16); - return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc; + return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc; } - return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; + if (ctx->device->coopmat_support) { + return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; + } + return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; } static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) { @@ -3523,6 +3895,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3555,6 +3928,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_id_f32; } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_id_bf16; + } if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_id_f16_f32.f16acc; @@ -3608,6 +3984,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3759,9 +4136,9 @@ static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf } } -static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bool one_time = true) { +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, q); + s.buffer = ggml_vk_create_cmd_buffer(device, p); if (one_time) { s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); } else { @@ -3771,7 +4148,33 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo return s; } -static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { +template size_t push_constant_size(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + GGML_UNUSED(t); + return sizeof(T); +} +template size_t push_constant_size(const std::vector &t) { + GGML_UNUSED(t); + return sizeof(T) * t.size(); +} +template size_t push_constant_size(const std::array &t) { + GGML_UNUSED(t); + return sizeof(T) * N; +} + +template const T *push_constant_data(const T &t) { + static_assert(std::is_class::value, "T must be a struct/class"); + return &t; +} +template const T *push_constant_data(const std::vector &t) { + return t.data(); +} +template const T *push_constant_data(const std::array &t) { + return t.data(); +} + +template +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, const T &push_constants, std::array elements) { const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]); @@ -3780,14 +4183,14 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; } std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); - GGML_ASSERT(pipeline->descriptor_set_idx < pipeline->descriptor_sets.size()); - GGML_ASSERT(descriptor_buffer_infos.size() == pipeline->parameter_count); + GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); - vk::DescriptorSet& descriptor_set = pipeline->descriptor_sets[pipeline->descriptor_set_idx++]; + vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); - subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size, push_constants); + subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline->layout, @@ -3820,7 +4223,7 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { ggml_vk_ctx_end(subctx); } - subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->q) }); + subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) }); subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); } @@ -4021,7 +4424,9 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); } } else { - vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + std::lock_guard guard(dst->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); ggml_vk_ctx_end(subctx); @@ -4033,6 +4438,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); } } @@ -4109,7 +4515,9 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ memcpy(dst, (uint8_t *) src->ptr + offset, size); } else { - vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + std::lock_guard guard(src->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); ggml_vk_ctx_end(subctx); @@ -4117,6 +4525,7 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ ggml_vk_submit(subctx, src->device->fence); VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); for (auto& cpy : subctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -4136,15 +4545,17 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { if (src->device == dst->device) { + std::lock_guard guard(src->device->mutex); VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); // Copy within the device - vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue); + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); } else { VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); // Copy device to device @@ -4169,7 +4580,8 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); - vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue); + std::lock_guard guard(dst->device->mutex); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); ggml_vk_ctx_end(subctx); @@ -4177,6 +4589,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz ggml_vk_submit(subctx, dst->device->fence); VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); } static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) { @@ -4230,6 +4643,8 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; + + GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -4248,7 +4663,7 @@ static void ggml_vk_matmul( ggml_vk_sync_buffers(subctx); if (split_k == 1) { const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); return; } @@ -4256,10 +4671,10 @@ static void ggml_vk_matmul( const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ggml_vk_sync_buffers(subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); } static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { @@ -4307,7 +4722,7 @@ static void ggml_vk_matmul_id( ggml_vk_sync_buffers(subctx); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, nei0, nei1, nbi1, ne11, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); } static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { @@ -4343,6 +4758,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f16_f16; } } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f32; + } else { + return ctx->device->pipeline_cpy_f16_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_bf16; + } else { + return ctx->device->pipeline_cpy_f32_bf16; + } + } if (src->type == GGML_TYPE_F32) { switch (to) { case GGML_TYPE_Q4_0: @@ -4371,6 +4800,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } } + if (src->type == to) { + // Copy two or four bytes at a time, depending on block size. + // For quantized types, we scale by block size/type size. But + // this path is also used for bf16->bf16 for example, where the + // type size must be exactly 2 or 4. + GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); + if ((ggml_type_size(src->type) % 4) == 0) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_contig_cpy_f16_f16; + } + } + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; GGML_ABORT("fatal error"); } @@ -4401,7 +4843,7 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& }; init_pushconst_fastdiv(pc); ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); } static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { @@ -4420,7 +4862,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 }); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); } static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -4470,8 +4912,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; @@ -4481,25 +4927,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (mmp == nullptr) { // Fall back to f16 dequant mul mat - mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); quantize_y = false; } const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig); + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat - mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; @@ -4520,12 +4966,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline to_q8_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -4557,18 +5003,18 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } if (quantize_y) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); } return; } @@ -4616,7 +5062,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } else if (qx_needs_dequant) { const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); } if (y_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); @@ -4750,12 +5196,12 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& // Request descriptor sets if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -4832,7 +5278,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, - sizeof(vk_mat_vec_push_constants), &pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); } static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -4888,7 +5334,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); return; } @@ -4920,7 +5366,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c } ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z }); } static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -4942,6 +5388,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t nb01 = src0->nb[1]; const uint64_t nb02 = src0->nb[2]; + const uint64_t nb12 = src1->nb[2]; + // const uint64_t ne10 = src1->ne[0]; const uint64_t ne11 = src1->ne[1]; const uint64_t ne12 = src1->ne[2]; @@ -4967,6 +5415,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_y = nb12 / sizeof(float); const uint64_t qx_sz = ggml_nbytes(src0); const uint64_t qy_sz = ggml_nbytes(src1); @@ -4974,7 +5423,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1); return; } @@ -4997,10 +5446,10 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; // compute - const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, - { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); } static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -5022,7 +5471,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) // when ne12 and ne13 are one. } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && - (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); @@ -5049,7 +5498,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - GGML_ASSERT(nei0 * nei1 <= 3072); + GGML_ASSERT(nei0 * nei1 <= 4096); const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi2 = ids->nb[2]; @@ -5090,27 +5539,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat - mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; @@ -5129,12 +5582,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -5157,12 +5610,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } return; } @@ -5212,7 +5665,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, - { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc.size() * sizeof(uint32_t), pc.data(), { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); } if (y_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); @@ -5351,12 +5804,12 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte // Request descriptor sets if (qx_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_0, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } if (qy_needs_dequant) { - ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx->device, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); return; } @@ -5432,7 +5885,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, - sizeof(vk_mat_vec_id_push_constants), &pc, { groups_x, (uint32_t)nei0, groups_z }); + pc, { groups_x, (uint32_t)nei0, groups_z }); } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { @@ -5444,6 +5897,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) { + // Needs to be kept up to date on shader changes + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = scalar_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t acctype = f32acc ? 4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t Qf = Br * (D / 4 + 2) * f16vec4; + + const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = D / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; @@ -5494,20 +5977,110 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - vk_pipeline *pipelines; - // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= flash_attention_num_small_rows; - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = scalar_flash_attention_num_large_rows; + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; default: - assert(!"unsupported D value"); - return; + GGML_ASSERT(0); + } + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. + gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + + vk_pipeline *pipelines; + bool small_rows = N <= get_fa_num_small_rows(path); + + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } + + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + switch (path) { + case FA_SCALAR: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + case FA_COOPMAT1: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + case FA_COOPMAT2: + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + break; + default: + GGML_ASSERT(0); } assert(pipelines); @@ -5525,27 +6098,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); - uint32_t gqa_ratio = 1; - uint32_t qk_ratio = neq2 / nek2; - uint32_t workgroups_x = (uint32_t)neq1; - uint32_t workgroups_y = (uint32_t)neq2; - uint32_t workgroups_z = (uint32_t)neq3; - - if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows && - qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { - // grouped query attention - make the N dimension equal to gqa_ratio, reduce - // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 - // and change addressing calculations to index Q's dimension 2. - gqa_ratio = qk_ratio; - N = gqa_ratio; - workgroups_y /= N; - } - uint32_t split_kv = KV; uint32_t split_k = 1; - if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) { - GGML_ASSERT(workgroups_x == 1); + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { // Try to run two workgroups per SM. split_k = ctx->device->shader_core_count * 2 / workgroups_y; if (split_k > 1) { @@ -5569,9 +6129,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); } return; } @@ -5675,7 +6235,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to // cancel out the divide by wg_denoms[0]. - sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); ggml_vk_sync_buffers(subctx); const std::array pc2 = { D, (uint32_t)ne1, split_k }; @@ -5684,7 +6244,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 }); + pc2, { (uint32_t)ne1, 1, 1 }); } else { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { @@ -5694,7 +6254,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, }, - sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z }); + pc, { workgroups_x, workgroups_y, workgroups_z }); } } @@ -5715,26 +6275,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_ADD: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; - } - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; - } - return nullptr; case GGML_OP_SUB: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32; - } - return nullptr; case GGML_OP_MUL: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; - } - return nullptr; case GGML_OP_DIV: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { + return nullptr; + } + switch (op) { + case GGML_OP_ADD: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_SUB: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_MUL: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_DIV: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + default: + break; } return nullptr; case GGML_OP_CONCAT: @@ -5828,37 +6399,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UNARY: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_silu_f32; - } - break; + return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_f32; - } - break; + return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU_QUICK: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_quick_f32; - } - break; + return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_relu_f32; - } - break; + return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TANH: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_tanh_f32; - } - break; + return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SIGMOID: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sigmoid_f32; - } - break; + return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; default: break; } @@ -5956,6 +6515,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_timestep_embedding_f32; } return nullptr; + case GGML_OP_CONV_TRANSPOSE_1D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_1d_f32; + } + return nullptr; case GGML_OP_POOL_2D: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_pool2d_f32; @@ -5981,6 +6545,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D_DW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f32; + } + } + return nullptr; default: return nullptr; } @@ -6006,6 +6579,9 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_CONV_2D_DW: + case GGML_OP_IM2COL: return true; default: return false; @@ -6118,7 +6694,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -6216,7 +6792,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co switch (op) { case GGML_OP_NORM: - case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_L2_NORM: case GGML_OP_SOFT_MAX: @@ -6233,6 +6808,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_RMS_NORM: + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + break; + case GGML_OP_SUM: // We use GGML_OP_SUM_ROWS with 1 row. elements = { 1, 1, 1 }; @@ -6275,6 +6854,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co uint32_t half_ceil = (dim + 1) / 2; elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1} + } break; case GGML_OP_POOL_2D: { const uint32_t N = dst->ne[3]; @@ -6299,8 +6882,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_CONV_2D_DW: { - const uint32_t ne = ggml_nelements(dst); + uint32_t ne = ggml_nelements(dst); + if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } if (ne > 262144) { elements = { 512, 512, CEIL_DIV(ne, 262144) }; } else if (ne > 512) { @@ -6339,7 +6932,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { // Empty src2 is possible in rope, but the shader needs a buffer vk_subbuffer subbuf_z; @@ -6350,26 +6943,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co } ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_IM2COL) { // im2col uses only src1 and dst buffers ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (op == GGML_OP_COUNT_EQUAL) { ggml_vk_sync_buffers(subctx); // count_equal assumes that destination buffer is initialized with zeroes ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src2) { ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else if (use_src1) { ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } else { ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); } } @@ -6482,7 +7075,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(pipeline != nullptr); if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -6538,7 +7131,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + }, pc, elements); } else if (version == 7) { ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, @@ -6549,7 +7142,7 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements); + }, pc, elements); } else { // shouldn't happen GGML_ASSERT(false); @@ -6621,7 +7214,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont GGML_ASSERT(pipeline != nullptr); if (dryrun) { - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return; } @@ -6686,7 +7279,7 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont vk_subbuffer{ d_GM, gm_offset, gm_size }, vk_subbuffer{ d_GV, gv_offset, gv_size }, vk_subbuffer{ d_P, p_offset, p_size }, - }, sizeof(vk_op_push_constants), &pc, elements); + }, pc, elements); } static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { @@ -6850,8 +7443,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t dst_type_size = ggml_type_size(dst->type); + uint32_t ne = (uint32_t)ggml_nelements(src0); + if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, { - (uint32_t)ggml_nelements(src0), + ne, (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, 0, @@ -6883,7 +7487,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -7047,6 +7661,37 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context }, dryrun); } +static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + // src0: (K, Cout, Cin, 1) -- kernel + // src1: (L, Cin, 1, 1) -- input + // dst: (*, Cout, 1, 1) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + const int32_t s0 = dst->op_params[0]; + + vk_op_conv_transpose_1d_push_constants p{}; + p.Cout = static_cast(ne01); + p.Cin = static_cast(ne02); + p.K = static_cast(ne00); + p.L = static_cast(ne10); + p.KL = static_cast(ne0); + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb11 = static_cast(nb11 / nb10); + p.nb1 = static_cast(nb1 / nb0); + p.s0 = static_cast(s0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -7075,6 +7720,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + vk_op_conv2d_dw_push_constants p{}; + p.ne = ggml_nelements(dst); + p.channels = dst->ne[2]; + p.batches = dst->ne[3]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.src_w = src1->ne[0]; + p.src_h = src1->ne[1]; + p.knl_w = src0->ne[0]; + p.knl_h = src0->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); +} + static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const float * op_params = (const float *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); @@ -7223,9 +7892,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -7240,7 +7909,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -7282,7 +7951,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); for (size_t i = 0; i < num_it; i++) { ggml_vk_matmul( @@ -7298,6 +7967,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); double time = std::chrono::duration_cast(end-begin).count() / 1000.0; @@ -7399,16 +8069,13 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); - ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); - ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); ggml_vk_destroy_buffer(d_D); - ggml_pipeline_cleanup(p); - ggml_pipeline_cleanup(ctx->device->pipeline_matmul_split_k_reduce); - free(x); free(y); free(d); @@ -7486,20 +8153,20 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_quantize_data(x, qx, ne, quant); ggml_vk_dequantize_data(qx, x_ref, ne, quant); - ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + ggml_pipeline_request_descriptor_sets(ctx, p, 1); if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; - ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc.size() * sizeof(int), pc.data(), { (uint32_t)ne, 1, 1}); + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1}); ggml_vk_ctx_end(subctx); auto begin = std::chrono::high_resolution_clock::now(); @@ -7507,6 +8174,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); @@ -7586,17 +8254,17 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); // -// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); +// ggml_pipeline_request_descriptor_sets(ctx, p, 1); // // if (ctx->device->need_compiles) { // ggml_vk_load_shaders(ctx->device); // } // -// ggml_pipeline_allocate_descriptor_sets(ctx->device); +// ggml_pipeline_allocate_descriptor_sets(ctx); // // ggml_vk_buffer_write(x_buf, 0, x, x_sz); // -// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); +// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); // ggml_vk_ctx_begin(ctx->device, subctx); // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); // ggml_vk_ctx_end(subctx); @@ -7606,6 +8274,7 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ // ggml_vk_submit(subctx, ctx->fence); // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); // ctx->device->device.resetFences({ ctx->fence }); +// ggml_vk_queue_command_pools_cleanup(ctx->device); // // auto end = std::chrono::high_resolution_clock::now(); // @@ -7745,9 +8414,9 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, // y[i] = i % k; } - ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { // Resize buffer @@ -7758,19 +8427,19 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } if (mmq) { - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it); + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); } if (ctx->device->need_compiles) { ggml_vk_load_shaders(ctx->device); } - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); ggml_vk_buffer_write(y_buf, 0, y, y_sz); - vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ggml_vk_ctx_begin(ctx->device, subctx); if (mmq) { for (size_t i = 0; i < num_it; i++) { @@ -7799,6 +8468,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ggml_vk_submit(subctx, ctx->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); auto end = std::chrono::high_resolution_clock::now(); @@ -8094,7 +8764,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: @@ -8111,7 +8783,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod if (!dryrun) { if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); ctx->compute_ctx = compute_ctx; ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { @@ -8157,13 +8829,15 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { // These operations all go through ggml_vk_op_f32, so short-circuit and // do the only thing needed for the dryrun. vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); - ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); return false; } default: @@ -8327,10 +9001,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_TIMESTEP_EMBEDDING: ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D_DW: + ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_LEAKY_RELU: ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); @@ -8374,7 +9056,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->tensor_ctxs[node_idx] = compute_ctx; -#if defined(GGML_VULKAN_CHECK_RESULTS) || defined(GGML_VULKAN_PERF) +#if defined(GGML_VULKAN_CHECK_RESULTS) // Force context reset on each node so that each tensor ends up in its own context // and can be run and compared to its CPU equivalent separately last_node = true; @@ -8451,7 +9133,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: @@ -8545,19 +9229,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { } ctx->gc.temp_buffers.clear(); - for (auto& dsr : ctx->device->pipeline_descriptor_set_requirements) { - vk_pipeline_ref plr = ctx->device->pipelines[dsr.first]; - - if (plr.expired()) { - continue; - } - - vk_pipeline pl = plr.lock(); - ggml_pipeline_cleanup(pl); - } - - ggml_vk_queue_cleanup(ctx->device, ctx->device->compute_queue); - ggml_vk_queue_cleanup(ctx->device, ctx->device->transfer_queue); + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -8578,7 +9251,8 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->tensor_ctxs.clear(); ctx->gc.contexts.clear(); - ctx->device->pipeline_descriptor_set_requirements.clear(); + ctx->pipeline_descriptor_set_requirements = 0; + ctx->descriptor_set_idx = 0; } // Clean up on backend free @@ -8605,6 +9279,15 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->device->device.destroyFence(ctx->fence); ctx->device->device.destroyFence(ctx->almost_ready_fence); + + for (auto& pool : ctx->descriptor_pools) { + ctx->device->device.destroyDescriptorPool(pool); + } + ctx->descriptor_pools.clear(); + ctx->descriptor_sets.clear(); + + ctx->compute_cmd_pool.destroy(ctx->device->device); + ctx->transfer_cmd_pool.destroy(ctx->device->device); } static int ggml_vk_get_device_count() { @@ -8792,8 +9475,7 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_ try { ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); } catch (vk::SystemError& e) { - std::cerr << "ggml_vulkan: Failed to allocate pinned memory." << std::endl; - std::cerr << "ggml_vulkan: " << e.what() << std::endl; + GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what()); // fallback to cpu buffer return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); } @@ -8872,7 +9554,7 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -8895,7 +9577,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -8918,7 +9600,7 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ if (ctx->transfer_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->device->transfer_queue); + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = transfer_ctx; ggml_vk_ctx_begin(ctx->device, transfer_ctx); } else { @@ -8979,7 +9661,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_load_shaders(ctx->device); } ggml_vk_preallocate_buffers(ctx); - ggml_pipeline_allocate_descriptor_sets(ctx->device); + ggml_pipeline_allocate_descriptor_sets(ctx); int last_node = cgraph->n_nodes - 1; @@ -8994,6 +9676,29 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch + vk_context compute_ctx; + if (vk_perf_logger_enabled) { + // allocate/resize the query pool + if (ctx->device->num_queries < cgraph->n_nodes + 1) { + if (ctx->device->query_pool) { + ctx->device->device.destroyQueryPool(ctx->device->query_pool); + } + vk::QueryPoolCreateInfo query_create_info; + query_create_info.queryType = vk::QueryType::eTimestamp; + query_create_info.queryCount = cgraph->n_nodes + 100; + ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); + ctx->device->num_queries = query_create_info.queryCount; + } + + ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); + + GGML_ASSERT(ctx->compute_ctx.expired()); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); + } + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). @@ -9021,6 +9726,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); + if (vk_perf_logger_enabled) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1); + } + if (enqueued) { ++submitted_nodes; @@ -9042,9 +9758,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg } } -#ifdef GGML_VULKAN_PERF - ctx->device->perf_logger->print_timings(); -#endif + if (vk_perf_logger_enabled) { + // End the command buffer and submit/wait + GGML_ASSERT(!ctx->compute_ctx.expired()); + compute_ctx = ctx->compute_ctx.lock(); + ggml_vk_ctx_end(compute_ctx); + + ggml_vk_submit(compute_ctx, ctx->device->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); + ctx->device->device.resetFences({ ctx->device->fence }); + + // Get the results and pass them to the logger + std::vector timestamps(cgraph->n_nodes + 1); + VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ggml_vk_is_empty(cgraph->nodes[i])) { + ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); + } + } + + ctx->device->perf_logger->print_timings(); + } ggml_vk_graph_cleanup(ctx); @@ -9188,7 +9922,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); default: return false; } @@ -9206,6 +9943,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (src0_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -9241,19 +9979,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (a->ne[3] != b->ne[3]) { return false; } - if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || + if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) || !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { return false; } + if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) { + // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader. + // So don't support this combination for now. + return false; + } return true; } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - if (!ggml_vk_get_device(ctx->device)->coopmat2) { - return false; - } + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; switch (op->src[0]->ne[0]) { case 64: case 80: @@ -9286,10 +10028,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[1]->type) { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -9305,10 +10049,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } break; default: return false; } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } return true; } case GGML_OP_GET_ROWS: @@ -9316,6 +10068,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -9346,6 +10099,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (src1_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -9359,6 +10113,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } if (src1_type == GGML_TYPE_F32) { switch (src0_type) { + case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -9374,6 +10129,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { return true; } + + // We can handle copying from a type to the same type if it's + // contiguous (memcpy). We use f16 or f32 shaders to do the copy, + // so the type/block size must be a multiple of 4. + if (src0_type == src1_type && + ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (ggml_type_size(src0_type) % 2) == 0) { + return true; + } return false; } break; case GGML_OP_REPEAT: @@ -9387,16 +10151,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + case GGML_OP_RMS_NORM: return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: case GGML_OP_SQR: @@ -9420,12 +10187,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: return true; + case GGML_OP_CONV_TRANSPOSE_1D: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; default: return false; } @@ -9572,8 +10342,9 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { switch (props.vendorID) { case VK_VENDOR_ID_INTEL: - // Intel drivers don't support coopmat properly yet - return false; + // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost, + // while some older hardware (ex. Arc A770) has performance regressions + return arch == vk_device_architecture::INTEL_XE2; case VK_VENDOR_ID_AMD: if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { // Workaround for AMD proprietary driver reporting support on all GPUs @@ -9775,7 +10546,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); @@ -9917,6 +10688,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { const int32_t dim = tensor->op_params[0]; const int32_t max_period = tensor->op_params[1]; tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ + const int32_t s0 = tensor->op_params[0]; + const int32_t p0 = tensor->op_params[1]; + const int32_t d0 = tensor->op_params[2]; + tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); } else if (tensor->op == GGML_OP_POOL_2D) { enum ggml_op_pool op = static_cast(tensor->op_params[0]); const int32_t k0 = tensor->op_params[1]; @@ -10064,7 +10840,8 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_vk_print_graph_origin(tensor, done); GGML_ABORT("fatal error"); } - if (first_error[0] == -1 && std::fabs(correct - result) > 0.1f) { + const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f; + if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) { first_error[0] = i0; first_error[1] = i1; first_error[2] = i2; @@ -10076,7 +10853,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { // Special case, value is infinite, avoid NaN result in avg_err // NaN also appears in results, if both are nan error is 0 if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { - avg_err += std::fabs(correct - result); + avg_err += std::fabs(correct - result) / denom; } counter++; } @@ -10111,7 +10888,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) { ggml_vk_print_graph_origin(tensor, done); } - if (avg_err > 0.05 || std::isnan(avg_err)) { + if (avg_err > 0.5 || std::isnan(avg_err)) { std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; if (src0 != nullptr) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index d6e0b2a5a..14e9daaa0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -5,13 +5,21 @@ find_package (Threads REQUIRED) if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat glslc support") endif() if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") endif() +if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") +endif() + set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index dd828c232..6567a8c54 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -18,7 +18,11 @@ void main() { // fast path for when all four iterations are in-bounds if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#ifndef OPTIMIZATION_ERROR_WORKAROUND + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); #else data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; @@ -31,7 +35,10 @@ void main() { continue; } -#ifndef OPTIMIZATION_ERROR_WORKAROUND +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); #else data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 000000000..938c74da5 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp new file mode 100644 index 000000000..b17b4e83e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -0,0 +1,98 @@ +#version 450 + +#include "types.comp" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout] + +layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +} p; + + +uint32_t Cout_idx = gl_WorkGroupID.x; +const uint32_t bs = gl_WorkGroupSize.x; +uint32_t tid = gl_LocalInvocationID.x; +// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K. +uint32_t tmp_len = bs*p.s0+p.K; +shared D_TYPE tmp[4096]; + +uint splitWork(uint workSize){ + return (bs + workSize -1) / bs; +} + +void main(){ + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx < tmp_len){ + tmp[idx] = 0.0; + } + } + + uint32_t L_blocks = splitWork(p.L); + for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){ + if(L_block_id > 0){ + barrier(); + // Shift values in tmp to the current processing window + for(int i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx >= bs*p.s0 && idx < tmp_len){ + tmp[idx-bs*p.s0] = tmp[idx]; + tmp[idx] = 0.0; + }else if(idx >= p.K && idx < bs*p.s0){ + tmp[idx] = 0.0; + } + } + } + barrier(); + + // Save contributions of the block to tmp + uint32_t L_idx = L_block_id*bs + tid; + for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){ + D_TYPE dp = 0.0; + for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){ + A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02]; + if(L_idx < p.L){ + B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11]; + dp = fma(elemKrn, elemInp, dp); + } + } + tmp[tid*p.s0 + K_idx] += dp; + barrier(); + } + + // Save the computed values except the last block that can have different size + uint32_t KLb_idx = L_block_id*bs*p.s0; + if(L_block_id < L_blocks-1){ + for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){ + uint32_t sh_idx = p.s0*tid+s0_idx; + uint32_t KL_idx = KLb_idx+sh_idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx]; + } + } + } + } + + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 29c906494..f476a2e3d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,10 @@ void main() { return; } -#ifndef OPTIMIZATION_ERROR_WORKAROUND +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); #else data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 2a162a2c8..0d9739d40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_BF16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); +} +#endif + #if defined(DATA_A_Q4_0) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif -#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 962d2353f..9cb7da2da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -482,7 +482,7 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo const uint ib8 = (idx & 0x18) >> 3; // 0..3 const uint iqs = 8 * ib32 + ib8; - const uint8_t qs = bl.block.qs[iqs]; + const uint qs = bl.block.qs[iqs]; const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp index 39184ef58..b604c1881 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -1,6 +1,6 @@ #version 450 -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #include "dequant_head.comp" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 000000000..ce230a8f7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,337 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.comp" +#include "flash_attn_base.comp" + +const uint32_t D_per_thread = D / D_split; + +const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][D / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if (p.mask != 0) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = Sf[r][0]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp new file mode 100644 index 000000000..61d90e2d8 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -0,0 +1,162 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; +layout (constant_id = 4) const uint32_t Clamp = 0; +layout (constant_id = 5) const uint32_t D_split = 16; + + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 000000000..da478be24 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,360 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.comp" +#include "flash_attn_base.comp" + +const uint32_t D_per_thread = D / D_split; +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for D==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = D / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t c = (idx + tid) / (D / 4); + if (c < Bc && d < D / 4) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x D * D x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over D in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < D / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if (p.mask != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)])); + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = float16_t(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= float16_t(Lfrcp[r]); + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + tile_row(r)) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index e1baa85f9..6acf67a03 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -18,62 +18,12 @@ #include "types.comp" #include "dequant_funcs_cm2.comp" - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (constant_id = 1) const uint32_t Br = 32; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t D = 32; -layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV; - -layout (push_constant) uniform parameter { - uint32_t N; - uint32_t KV; - - uint32_t ne1; - uint32_t ne2; - uint32_t ne3; - - uint32_t neq2; - uint32_t neq3; - uint32_t nek2; - uint32_t nek3; - uint32_t nev2; - uint32_t nev3; - uint32_t nem1; - - uint32_t nb01; - uint32_t nb02; - uint32_t nb03; - uint32_t nb11; - uint32_t nb12; - uint32_t nb13; - uint32_t nb21; - uint32_t nb22; - uint32_t nb23; - uint32_t nb31; - - float scale; - float max_bias; - float logit_softcap; - - uint32_t mask; - uint32_t n_head_log2; - float m0; - float m1; - - uint32_t gqa_ratio; - uint32_t split_kv; - uint32_t k_num; -} p; +#include "flash_attn_base.comp" layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; layout (binding = 3) readonly buffer M {uint8_t data_m[];}; -layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; - -#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { return max(x, y); @@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// Store column zero. This is used to save per-row m and L values for split_k. -ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - if (r < N && c == 0) { - uint32_t offset = iq2 + r; - data_o[o_offset + offset] = D_TYPE(elem); - } - return elem; -} - -// Load the slope matrix, indexed by Q's dimension 2. -ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) -{ - const uint32_t h = iq2 + (r & (p.gqa_ratio - 1)); - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1); - - return ACC_TYPE(pow(base, ACC_TYPE(exph))); -} - void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); #endif - const uint32_t N = p.N; - const uint32_t KV = p.KV; - - uint32_t i = gl_WorkGroupID.x; - uint32_t split_k_index = 0; - - if (p.k_num > 1) { - i = 0; - split_k_index = gl_WorkGroupID.x; - } - - const uint32_t Tr = CEIL_DIV(N, Br); - - const uint32_t start_j = split_k_index * p.split_kv / Bc; - const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - - // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. - // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. - const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; - const uint32_t iq3 = gl_WorkGroupID.z; - - // broadcast factors - const uint32_t rk2 = p.neq2/p.nek2; - const uint32_t rk3 = p.neq3/p.nek3; - - const uint32_t rv2 = p.neq2/p.nev2; - const uint32_t rv3 = p.neq3/p.nev3; - - // k indices - const uint32_t ik3 = iq3 / rk3; - const uint32_t ik2 = iq2 / rk2; - - // v indices - const uint32_t iv3 = iq3 / rv3; - const uint32_t iv2 = iq2 / rv2; + init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); @@ -195,17 +90,6 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements. - // When using grouped query attention, Q is indexed by iq2, so the stride - // should be nb02 (which is in bytes). - uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; - uint32_t k_stride = p.nb11; - uint32_t v_stride = p.nb21; - // When using grouped query attention, all rows use the same mask (stride 0). - // "p.gqa_ratio >> 16" is just a roundabout way of writing zero - // that prevents the compiler from folding the "&" through the select - // and breaking the alignment detection. - uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index e877ed779..ee6b86a18 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -20,9 +20,14 @@ void main() { const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; -#ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); +#if defined(DATA_A_BF16) + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); #else - data_d[d_offset + i00] = data_a[a_offset + i00]; + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); +#endif +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(v); +#else + data_d[d_offset + i00] = D_TYPE(v); #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 775b48cd0..bb429dd59 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -6,7 +6,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else #define K_PER_ITER 2 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index 48376637f..bc633369f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -21,7 +21,9 @@ layout (push_constant) uniform parameter uint nrows_x; uint row_stride_x; uint channel_stride_x; + uint channel_stride_y; uint channel_x_divisor; + uint ne12; uint b_offset; uint d_offset; } p; @@ -33,6 +35,7 @@ void main() { const uint row_x = gl_GlobalInvocationID.y; const uint channel = gl_GlobalInvocationID.z; const uint channel_x = channel / p.channel_x_divisor; + const uint channel_y = channel % p.ne12; const uint nrows_y = p.ncols_x; const uint nrows_dst = p.nrows_x; @@ -56,7 +59,7 @@ void main() { const uint row_y = col_x; const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; + const uint iy = channel_y*p.channel_stride_y + row_y; const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]); @@ -72,7 +75,7 @@ void main() { const uint row_y = col_x; const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; + const uint iy = channel_y*p.channel_stride_y + row_y; const vec4 av4 = vec4(data_a_v4[ix / 4]); const vec4 bv4 = vec4(data_b_v4[iy / 4]); @@ -89,7 +92,7 @@ void main() { const uint row_y = col_x; const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; + const uint iy = channel_y*p.channel_stride_y + row_y; const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 23ce8ceec..26163b167 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -7,7 +7,11 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif #if defined(DATA_A_IQ1_M) -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#if defined(DATA_A_BF16) && defined(COOPMAT) +#extension GL_EXT_bfloat16 : enable #endif #ifdef COOPMAT @@ -29,6 +33,10 @@ #define LOAD_VEC_B 1 #endif +#if !defined(TO_FLOAT_TYPE) +#define TO_FLOAT_TYPE FLOAT_TYPE +#endif + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; @@ -95,7 +103,7 @@ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID -shared u16vec2 row_ids[3072]; +shared u16vec2 row_ids[4096]; #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -202,8 +210,8 @@ void main() { #endif #ifdef COOPMAT - coopmat cache_a; - coopmat cache_b; + coopmat cache_a; + coopmat cache_b; coopmat sums[cms_per_row * cms_per_col]; [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { @@ -248,6 +256,21 @@ void main() { buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); } #endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w); +#else + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + } else { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0)); + } +#endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; @@ -695,13 +718,13 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); + buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); #elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); } else { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); } @@ -709,7 +732,7 @@ void main() { const uint row_i = ic * BN + loadc_b + l; if (row_i < _ne1) { const u16vec2 row_idx = row_ids[row_i]; - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); } else { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 06b7ab09e..918465757 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -14,6 +14,9 @@ #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable +#ifdef DATA_A_BF16 +#extension GL_EXT_bfloat16 : enable +#endif #include "types.comp" @@ -80,10 +83,16 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #define store_scales(a) #endif +#if defined(DATA_A_BF16) +#define MAT_TYPE bfloat16_t +#else +#define MAT_TYPE FLOAT_TYPE +#endif + #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; -shared u16vec4 row_ids[3072]; +shared u16vec4 row_ids[4096]; layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { B_TYPE b[]; @@ -271,8 +280,8 @@ void main() { // Manually partial unroll [[unroll]] for (uint j = 0; j < unroll_count; ++j) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); @@ -286,8 +295,8 @@ void main() { store_scales(tid); } while (block_k < end_k) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); @@ -310,8 +319,8 @@ void main() { // Manually partial unroll [[unroll]] for (uint j = 0; j < unroll_count; ++j) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); @@ -325,8 +334,8 @@ void main() { store_scales(tid); } while (block_k < end_k) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); @@ -350,8 +359,8 @@ void main() { // Manually partial unroll [[unroll]] for (uint j = 0; j < unroll_count; ++j) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); @@ -365,8 +374,8 @@ void main() { store_scales(tid); } while (block_k < end_k) { - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); @@ -405,8 +414,8 @@ void main() { fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } - coopmat mat_a; - coopmat mat_b; + coopmat mat_a; + coopmat mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 284a35caa..83de90eb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -101,7 +101,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; #define LOAD_VEC_B 4 #ifdef MUL_MAT_ID -shared u16vec2 row_ids[3072]; +shared u16vec2 row_ids[4096]; #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp index 52a19b62a..4f806270c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -17,5 +17,5 @@ void main() { return; } - data_d[i] = max(float(data_a[i]), 0); + data_d[i] = D_TYPE(max(float(data_a[i]), 0)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index b554400ba..deb8ee996 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.comp" +#include "generic_unary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,29 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); sum[tid] += xi * xi; } @@ -33,10 +43,10 @@ void main() { barrier(); } - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp index 776581e2c..5c9e5c350 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i]))); + data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i])))); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 495f966bd..8a6f868f5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); + data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp new file mode 100644 index 000000000..fd0ba401f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_bfloat16 : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index f5b29bfb1..3bde71783 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -33,6 +33,19 @@ #endif #endif +#if defined(DATA_A_BF16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE uint16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE u16vec4 +#elif LOAD_VEC_A == 8 +#error unsupported +#endif +#endif + #define QUANT_K_Q4_0 32 #define QUANT_R_Q4_0 2 @@ -1343,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize) } #endif +// returns the bfloat value in the low 16b. +// See ggml_compute_fp32_to_bf16 +uint32_t fp32_to_bf16(float f) +{ + uint32_t u = floatBitsToUint(f); + u = (u + (0x7fff + ((u >> 16) & 1))) >> 16; + return u; +} + +float bf16_to_fp32(uint32_t u) +{ + return uintBitsToFloat(u << 16); +} + #endif // !defined(GGML_TYPES_COMP) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index cf74625cc..c63345ec8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -63,7 +63,8 @@ const std::vector type_names = { "iq3_xxs", "iq3_s", "iq4_xs", - "iq4_nl" + "iq4_nl", + "bf16", }; namespace { @@ -214,7 +215,7 @@ static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; void string_to_spv_func(const std::string& _name, const std::string& in_fname, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_coopmat" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string name = _name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); std::string out_fname = join_paths(output_dir, name + ".spv"); std::string in_path = join_paths(input_dir, in_fname); @@ -296,7 +297,6 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; std::map base_dict = { - {"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}, {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"}, }; std::string shader_name = "matmul"; @@ -318,12 +318,45 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; - // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string { + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; + } + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + }; - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + // Shaders with f16 B_TYPE + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + + // bf16 + { + std::string load_vec_a_unaligned = "1"; + // For aligned matmul loads + std::string load_vec_a = coopmat2 ? "1" : "4"; + + // scalar path promotes to float + std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader +#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!(coopmat || coopmat2)) +#endif + { + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } for (const auto& tname : type_names) { std::string load_vec_quant = "2"; @@ -332,26 +365,30 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) load_vec_quant = "4"; + if (tname == "bf16") { + continue; + } + std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : load_vec_quant; + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; // For aligned matmul loads - std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : load_vec_quant; + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { - string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif } @@ -384,16 +421,18 @@ void process_shaders() { #endif } -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; + std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; for (const auto& tname : type_names) { if (tname == "f32") { continue; } + if (tname == "bf16") continue; +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); @@ -402,9 +441,27 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } } } -#endif for (const auto& tname : type_names) { // mul mat vec @@ -417,12 +474,12 @@ void process_shaders() { string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); // Dequant shaders - if (tname != "f16") { + if (tname != "f16" && tname != "bf16") { string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); } if (!string_ends_with(tname, "_k")) { - shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp"; if (tname == "f16") { string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); @@ -447,9 +504,13 @@ void process_shaders() { string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -457,8 +518,26 @@ void process_shaders() { string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + auto get_type_str = [](bool f16) { + return f16 ? "float16_t" : "float"; + }; + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + for (std::string op : {"add", "sub", "mul", "div"}) { + for (auto src0_f16 : {false, true}) { + for (auto src1_f16 : {false, true}) { + for (auto dst_f16 : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + } + } + } + } string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -493,14 +572,21 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -536,6 +622,8 @@ void process_shaders() { string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); @@ -544,6 +632,9 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + for (auto &c : compiles) { c.wait(); } @@ -598,7 +689,12 @@ void write_output_files() { std::remove(path.c_str()); } } - + for (const char *op : {"add", "sub", "mul", "div"}) { + fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); + fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + } fclose(hdr); fclose(src); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 950772c75..a8edad377 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4,6 +4,7 @@ #include "ggml-backend.h" #include "ggml-impl.h" #include "ggml-threading.h" +#include "ggml-cpu.h" #include "ggml.h" // FIXME: required here for quantization functions @@ -63,12 +64,17 @@ // precomputed f32 table for f16 (256 KB) (ggml-impl.h) float ggml_table_f32_f16[1 << 16]; -#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \ - (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) +#if defined(__linux__) || \ + defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ + (defined(__APPLE__) && !TARGET_OS_TV && !TARGET_OS_WATCH) + #include #include #include #include +#if defined(__linux__) +#include +#endif #if defined(__ANDROID__) #include @@ -127,15 +133,46 @@ static void ggml_print_backtrace_symbols(void) { } #endif -static void ggml_print_backtrace(void) { +void ggml_print_backtrace(void) { const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); if (GGML_NO_BACKTRACE) { return; } - char attach[32]; - snprintf(attach, sizeof(attach), "attach %d", getpid()); - int pid = fork(); - if (pid == 0) { +#if defined(__linux__) + FILE * f = fopen("/proc/self/status", "r"); + size_t size = 0; + char * line = NULL; + ssize_t length = 0; + while ((length = getline(&line, &size, f)) > 0) { + if (!strncmp(line, "TracerPid:", sizeof("TracerPid:") - 1) && + (length != sizeof("TracerPid:\t0\n") - 1 || line[length - 2] != '0')) { + // Already being debugged, and the breakpoint is the later abort() + free(line); + fclose(f); + return; + } + } + free(line); + fclose(f); + int lock[2] = { -1, -1 }; + (void) !pipe(lock); // Don't start gdb until after PR_SET_PTRACER +#endif + const int parent_pid = getpid(); + const int child_pid = fork(); + if (child_pid < 0) { // error +#if defined(__linux__) + close(lock[1]); + close(lock[0]); +#endif + return; + } else if (child_pid == 0) { // child + char attach[32]; + snprintf(attach, sizeof(attach), "attach %d", parent_pid); +#if defined(__linux__) + close(lock[1]); + (void) !read(lock[0], lock, 1); + close(lock[0]); +#endif // try gdb execlp("gdb", "gdb", "--batch", "-ex", "set style enabled on", @@ -148,22 +185,22 @@ static void ggml_print_backtrace(void) { execlp("lldb", "lldb", "--batch", "-o", "bt", "-o", "quit", - "-p", attach, + "-p", &attach[sizeof("attach ") - 1], (char *) NULL); - exit(EXIT_FAILURE); - } else { - int wstatus; - waitpid(pid, &wstatus, 0); - if (WIFEXITED(wstatus)) { - if (WEXITSTATUS(wstatus) == EXIT_FAILURE) { - // gdb failed, fallback to backtrace_symbols - ggml_print_backtrace_symbols(); - } - } + // gdb failed, fallback to backtrace_symbols + ggml_print_backtrace_symbols(); + _Exit(0); + } else { // parent +#if defined(__linux__) + prctl(PR_SET_PTRACER, child_pid); + close(lock[1]); + close(lock[0]); +#endif + waitpid(child_pid, NULL, 0); } } #else -static void ggml_print_backtrace(void) { +void ggml_print_backtrace(void) { // platform not supported } #endif @@ -184,6 +221,8 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) { abort(); } +// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp + // // logging // @@ -382,58 +421,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) { } } -// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library -// currently, the ggml_cpu_has_* functions are entirely compile-time void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) { - int64_t i = 0; -#if defined(__F16C__) - //if (ggml_cpu_has_f16c()) { - for (; i + 7 < n; i += 8) { - __m256 x_vec = _mm256_loadu_ps(x + i); - __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storeu_si128((__m128i *)(y + i), y_vec); - } - for(; i + 3 < n; i += 4) { - __m128 x_vec = _mm_loadu_ps(x + i); - __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); - _mm_storel_epi64((__m128i *)(y + i), y_vec); - } - //} -#endif - for (; i < n; i++) { + int i = 0; + for (; i < n; ++i) { y[i] = GGML_FP32_TO_FP16(x[i]); } } void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) { - int64_t i = 0; -#if defined(__AVX512F__) - //if (ggml_cpu_has_avx512()) { - for (; i + 16 <= n; i += 16) { - _mm512_storeu_ps(y + i, - _mm512_castsi512_ps( - _mm512_slli_epi32( - _mm512_cvtepu16_epi32( - _mm256_loadu_si256( - (const __m256i *)(x + i))), - 16))); - } - //} -#endif -#if defined(__AVX2__) - //if (ggml_cpu_has_avx2()) { - for (; i + 8 <= n; i += 8) { - _mm256_storeu_ps(y + i, - _mm256_castsi256_ps( - _mm256_slli_epi32( - _mm256_cvtepu16_epi32( - _mm_loadu_si128( - (const __m128i *)(x + i))), - 16))); - } - //} -#endif - for (; i < n; i++) { + int i = 0; + for (; i < n; ++i) { y[i] = GGML_BF16_TO_FP32(x[i]); } } @@ -891,12 +888,6 @@ struct ggml_context { struct ggml_object * objects_end; }; -struct ggml_context_container { - bool used; - - struct ggml_context context; -}; - // // data types // @@ -956,6 +947,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CONV_TRANSPOSE_1D", "IM2COL", "IM2COL_BACK", + "CONV_2D_DW", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", @@ -993,7 +985,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1050,6 +1042,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "conv_transpose_1d(x)", "im2col(x)", "im2col_back(x)", + "conv_2d_dw(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", @@ -1087,7 +1080,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81"); +static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1107,9 +1100,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSWISH", "HARDSIGMOID", "EXP", + "GELU_ERF", }; -static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14"); +static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); @@ -1338,12 +1332,23 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) { return ggml_is_contiguous_n(tensor, 2); } +bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) { + return ggml_nbytes(tensor) == ggml_nelements(tensor) * ggml_type_size(tensor->type)/ggml_blck_size(tensor->type); +} + bool ggml_is_permuted(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3]; } +bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) { + return + tensor->nb[0] > tensor->nb[2] && + tensor->nb[1] > tensor->nb[0] && + tensor->nb[2] == ggml_type_size(tensor->type); +} + static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -2308,6 +2313,26 @@ struct ggml_tensor * ggml_repeat( return result; } +struct ggml_tensor * ggml_repeat_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) { + const bool can_repeat = ggml_is_empty(a) || ( + (ne0 % a->ne[0] == 0) && + (ne1 % a->ne[1] == 0) && + (ne2 % a->ne[2] == 0) && + (ne3 % a->ne[3] == 0) + ); + GGML_ASSERT(can_repeat); + + struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + + result->op = GGML_OP_REPEAT; + result->src[0] = a; + + return result; +} + // ggml_repeat_back struct ggml_tensor * ggml_repeat_back( @@ -2498,6 +2523,20 @@ struct ggml_tensor * ggml_gelu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU); } +// ggml_gelu_erf + +struct ggml_tensor * ggml_gelu_erf( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF); +} + +struct ggml_tensor * ggml_gelu_erf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a) { + return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF); +} + // ggml_gelu_quick struct ggml_tensor * ggml_gelu_quick( @@ -2760,11 +2799,11 @@ void ggml_mul_mat_set_prec( c = ggml_mul_mat_id(ctx, as, b, ids); as -> [cols, rows, n_expert] - ids -> [n_experts_used, n_tokens] (i32) b -> [cols, n_expert_used, n_tokens] + ids -> [n_expert_used, n_tokens] (i32) c -> [rows, n_expert_used, n_tokens] - in b, n_experts_used can be broadcasted to match the n_expert_used of ids + in b, n_expert_used can be broadcasted to match the n_expert_used of ids c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids */ @@ -4050,6 +4089,46 @@ struct ggml_tensor * ggml_conv_2d_dw( return result; } +// ggml_conv_2d_dw_direct + +struct ggml_tensor * ggml_conv_2d_dw_direct( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride0, + int stride1, + int pad0, + int pad1, + int dilation0, + int dilation1) { + GGML_ASSERT(a->ne[2] == 1); + GGML_ASSERT(a->ne[3] == b->ne[2]); + int64_t ne[4]; + ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0); + ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1); + ne[2] = b->ne[2]; + ne[3] = b->ne[3]; + + struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne); + + if (ggml_is_contiguous_channels(b)) { + // Result will be permuted the same way as input (CWHN order) + const int64_t type_size = ggml_type_size(result->type); + GGML_ASSERT(ggml_blck_size(result->type) == 1); + result->nb[0] = result->ne[2] * type_size; + result->nb[1] = result->ne[0] * result->nb[0]; + result->nb[2] = type_size; + } + + int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_CONV_2D_DW; + result->src[0] = a; + result->src[1] = b; + return result; +} + // ggml_conv_transpose_2d_p0 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { @@ -5487,7 +5566,7 @@ static void ggml_compute_backward( // tensor = src0 * 1 + src1 * 0 if (src0_needs_grads) { // dsrc0 = dtensor * 1 - ggml_add_or_set(ctx, cgraph, isrc0, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0)); } if (src1_needs_grads) { // dsrc1 = dtensor * 0 -> noop @@ -5768,10 +5847,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * } void ggml_build_backward_expand( - struct ggml_context * ctx_static, - struct ggml_context * ctx_compute, - struct ggml_cgraph * cgraph, - bool accumulate) { + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + struct ggml_tensor ** grad_accs) { GGML_ASSERT(cgraph->n_nodes > 0); GGML_ASSERT(cgraph->grads); GGML_ASSERT(cgraph->grad_accs); @@ -5844,21 +5922,24 @@ void ggml_build_backward_expand( GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); - const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); - GGML_ASSERT(igrad != GGML_HASHSET_FULL); - GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad)); - if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { - cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node); - cgraph->grads[igrad] = cgraph->grad_accs[igrad]; - ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name); + const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node); + GGML_ASSERT(ihash != GGML_HASHSET_FULL); + GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash)); + if (grad_accs && grad_accs[i]) { + cgraph->grad_accs[ihash] = grad_accs[i]; + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; + } else if (node->flags & GGML_TENSOR_FLAG_LOSS) { + // loss tensors always need a gradient accumulator + cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne); + cgraph->grads[ihash] = cgraph->grad_accs[ihash]; } - grads_needed[igrad] = true; + grads_needed[ihash] = true; } for (int i = n_nodes_f - 1; i >= 0; --i) { // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); + ggml_compute_backward(ctx, cgraph, i, grads_needed); } free(grads_needed); @@ -6004,8 +6085,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { } } -struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { - struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); +struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) { + struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads); ggml_graph_cpy(cgraph, result); return result; } @@ -6024,6 +6105,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { } void ggml_graph_reset(struct ggml_cgraph * cgraph) { + if (!cgraph) { + return; + } GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { @@ -6333,8 +6417,8 @@ void ggml_set_output(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_OUTPUT; } -void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { - GGML_UNUSED(ctx); // TODO: remove this parameter +void ggml_set_param(struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->op == GGML_OP_NONE); tensor->flags |= GGML_TENSOR_FLAG_PARAM; } diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp new file mode 100644 index 000000000..0d388d455 --- /dev/null +++ b/ggml/src/ggml.cpp @@ -0,0 +1,26 @@ +#include "ggml-impl.h" + +#include +#include + +static std::terminate_handler previous_terminate_handler; + +GGML_NORETURN static void ggml_uncaught_exception() { + ggml_print_backtrace(); + if (previous_terminate_handler) { + previous_terminate_handler(); + } + abort(); // unreachable unless previous_terminate_handler was nullptr +} + +static bool ggml_uncaught_exception_init = []{ + const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE"); + if (GGML_NO_BACKTRACE) { + return false; + } + const auto prev{std::get_terminate()}; + GGML_ASSERT(prev != ggml_uncaught_exception); + previous_terminate_handler = prev; + std::set_terminate(ggml_uncaught_exception); + return true; +}(); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 381a9c7dc..a0a318a29 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -299,10 +299,10 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vectorversion)) { - if (ctx->version == 1) { - fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); + if (ok && ctx->version == 0) { + GGML_LOG_ERROR("%s: bad GGUF version: %" PRIu32 "\n", __func__, ctx->version); ok = false; } - if (ctx->version > GGUF_VERSION) { - fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", + + /* + * bit layout is different when reading non-native endian models. + * assuming that the GGUF version is 3, the non-native endian model + * would read it as 0x30000000. we can use the AND operation against + * the last 4 hexadecimal digits to check if the model is the same + * endianness as the host system. + */ + if (ok && (ctx->version & 0x0000FFFF) == 0x00000000) { + GGML_LOG_ERROR("%s: failed to load model: this GGUF file version %" PRIu32 " is extremely large, is there a mismatch between the host and model endianness?\n", __func__, ctx->version); + ok = false; + } + + if (ok && ctx->version == 1) { + GGML_LOG_ERROR("%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__); + ok = false; + } + if (ok && ctx->version > GGUF_VERSION) { + GGML_LOG_ERROR("%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n", __func__, ctx->version, GGUF_VERSION); ok = false; } @@ -363,7 +380,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par if (ok && gr.read(n_tensors)) { static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) { - fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", + GGML_LOG_ERROR("%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n", __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info)); ok = false; } @@ -374,7 +391,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par if (ok && gr.read(n_kv)) { static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing"); if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) { - fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", + GGML_LOG_ERROR("%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n", __func__, n_kv, SIZE_MAX/sizeof(gguf_kv)); ok = false; } @@ -383,7 +400,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to read header\n", __func__); + GGML_LOG_ERROR("%s: failed to read header\n", __func__); gguf_free(ctx); return nullptr; } @@ -399,15 +416,15 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par try { ok = ok && gr.read(key); } catch (std::length_error &) { - fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i); ok = false; } catch (std::bad_alloc &) { - fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i); ok = false; } for (size_t j = 0; ok && j < ctx->kv.size(); ++j) { if (key == ctx->kv[j].key) { - fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); + GGML_LOG_ERROR("%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i); ok = false; } } @@ -441,14 +458,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par case GGUF_TYPE_ARRAY: default: { - fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); + GGML_LOG_ERROR("%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type); ok = false; } break; } } if (!ok) { - fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + GGML_LOG_ERROR("%s: failed to read key-value pairs\n", __func__); gguf_free(ctx); return nullptr; } @@ -458,7 +475,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx); if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) { - fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); + GGML_LOG_ERROR("%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment); gguf_free(ctx); return nullptr; } @@ -474,14 +491,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par try { ok = ok && gr.read(name); } catch (std::length_error &) { - fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i); ok = false; } catch (std::bad_alloc &) { - fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); + GGML_LOG_ERROR("%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i); ok = false; } if (name.length() >= GGML_MAX_NAME) { - fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); + GGML_LOG_ERROR("%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME); ok = false; break; } @@ -490,7 +507,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // make sure there are no duplicate tensor names for (int64_t j = 0; ok && j < i; ++j) { if (strcmp(info.t.name, ctx->info[j].t.name) == 0) { - fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); + GGML_LOG_ERROR("%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i); ok = false; break; } @@ -505,7 +522,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par uint32_t n_dims = -1; ok = ok && gr.read(n_dims); if (n_dims > GGML_MAX_DIMS) { - fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", + GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", __func__, info.t.name, n_dims, GGML_MAX_DIMS); ok = false; break; @@ -518,7 +535,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that all ne are non-negative if (info.t.ne[j] < 0) { - fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", + GGML_LOG_ERROR("%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n", __func__, info.t.name, j, info.t.ne[j]); ok = false; break; @@ -530,7 +547,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) || (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) { - fprintf(stderr, "%s: total number of elements in tensor '%s' with shape " + GGML_LOG_ERROR("%s: total number of elements in tensor '%s' with shape " "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n", __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX); ok = false; @@ -547,7 +564,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that tensor type is within defined range if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { - fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n", + GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n", __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); ok = false; break; @@ -557,7 +574,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that row size is divisible by block size if (blck_size == 0 || info.t.ne[0] % blck_size != 0) { - fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " + GGML_LOG_ERROR("%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, " "not a multiple of block size (%" PRId64 ")\n", __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size); ok = false; @@ -582,7 +599,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to read tensor info\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor info\n", __func__); gguf_free(ctx); return nullptr; } @@ -590,7 +607,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // we require the data section to be aligned, so take into account any padding if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { - fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__); + GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } @@ -604,9 +621,9 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par for (size_t i = 0; i < ctx->info.size(); ++i) { const gguf_tensor_info & ti = ctx->info[i]; if (ti.offset != ctx->size) { - fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", + GGML_LOG_ERROR("%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n", __func__, ti.t.name, ti.offset, ctx->size); - fprintf(stderr, "%s: failed to read tensor data\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor data\n", __func__); gguf_free(ctx); return nullptr; } @@ -634,7 +651,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par *params.ctx = ggml_init(pdata); if (*params.ctx == nullptr) { - fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__); + GGML_LOG_ERROR("%s: failed to initialize ggml context for storing tensors\n", __func__); gguf_free(ctx); return nullptr; } @@ -656,7 +673,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par ok = ok && gr.read(data->data, ctx->size); if (!ok) { - fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__); + GGML_LOG_ERROR("%s: failed to read tensor data binary blob\n", __func__); ggml_free(ctx_data); *params.ctx = nullptr; gguf_free(ctx); @@ -689,7 +706,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par } if (!ok) { - fprintf(stderr, "%s: failed to create tensors\n", __func__); + GGML_LOG_ERROR("%s: failed to create tensors\n", __func__); ggml_free(ctx_data); *params.ctx = nullptr; gguf_free(ctx); @@ -706,7 +723,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p FILE * file = ggml_fopen(fname, "rb"); if (!file) { - fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); return nullptr; } @@ -1305,7 +1322,7 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo FILE * file = ggml_fopen(fname, "wb"); if (!file) { - fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open file '%s' for writing GGUF data\n", __func__, fname); return false; } diff --git a/gguf-py/README.md b/gguf-py/README.md index dd4ab7bde..ca7e09c68 100644 --- a/gguf-py/README.md +++ b/gguf-py/README.md @@ -11,6 +11,11 @@ as an example for its usage. pip install gguf ``` +Optionally, you can install gguf with the extra 'gui' to enable the visual GGUF editor. +```sh +pip install gguf[gui] +``` + ## API Examples/Simple Tools [examples/writer.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model. @@ -25,6 +30,8 @@ pip install gguf [gguf/scripts/gguf_new_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values. +[gguf/scripts/gguf_editor_gui.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_editor_gui.py) — Allows for viewing, editing, adding, or removing metadata values within a GGUF file as well as viewing its tensors with a Qt interface. + ## Development Maintainers who participate in development of this package are advised to install it in editable mode: diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c2dbf7b64..064bb0ed9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -104,6 +104,7 @@ class Keys: EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" EXPERT_GATING_FUNC = "{arch}.expert_gating_func" + MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" POOLING_TYPE = "{arch}.pooling_type" LOGIT_SCALE = "{arch}.logit_scale" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" @@ -139,6 +140,8 @@ class Keys: REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" SLIDING_WINDOW = "{arch}.attention.sliding_window" SCALE = "{arch}.attention.scale" + KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" + VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" @@ -174,6 +177,9 @@ class Keys: EMBEDDING_LENGTH = "{arch}.convnext.embedding_length" BLOCK_COUNT = "{arch}.convnext.block_count" + class Classifier: + OUTPUT_LABELS = "{arch}.classifier.output_labels" + class Tokenizer: MODEL = "tokenizer.ggml.model" PRE = "tokenizer.ggml.pre" @@ -221,6 +227,46 @@ class Keys: CHUNK_SIZE = "imatrix.chunk_size" DATASETS = "imatrix.datasets" + class Clip: + PROJECTOR_TYPE = "clip.projector_type" + HAS_VISION_ENCODER = "clip.has_vision_encoder" + HAS_AUDIO_ENCODER = "clip.has_audio_encoder" + HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" + + class ClipVision: + IMAGE_SIZE = "clip.vision.image_size" + PATCH_SIZE = "clip.vision.patch_size" + EMBEDDING_LENGTH = "clip.vision.embedding_length" + FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" + PROJECTION_DIM = "clip.vision.projection_dim" + BLOCK_COUNT = "clip.vision.block_count" + IMAGE_MEAN = "clip.vision.image_mean" + IMAGE_STD = "clip.vision.image_std" + SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" + USE_GELU = "clip.use_gelu" + USE_SILU = "clip.use_silu" + N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + + class Attention: + HEAD_COUNT = "clip.vision.attention.head_count" + LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon" + + class Projector: + SCALE_FACTOR = "clip.vision.projector.scale_factor" + + class ClipAudio: + NUM_MEL_BINS = "clip.audio.num_mel_bins" + EMBEDDING_LENGTH = "clip.audio.embedding_length" + FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length" + PROJECTION_DIM = "clip.audio.projection_dim" + BLOCK_COUNT = "clip.audio.block_count" + + class Attention: + HEAD_COUNT = "clip.audio.attention.head_count" + LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon" + + class Projector: + STACK_FACTOR = "clip.audio.projector.stack_factor" # # recommended mapping of model tensor names for storage in gguf @@ -231,9 +277,11 @@ class GGUFType: MODEL = "model" ADAPTER = "adapter" IMATRIX = "imatrix" + MMPROJ = "mmproj" # dummy, unused for now class MODEL_ARCH(IntEnum): + MMPROJ = auto() # dummy arch for clip.cpp LLAMA = auto() LLAMA4 = auto() DECI = auto() @@ -248,6 +296,8 @@ class MODEL_ARCH(IntEnum): REFACT = auto() BERT = auto() NOMIC_BERT = auto() + NOMIC_BERT_MOE = auto() + NEO_BERT = auto() JINA_BERT_V2 = auto() BLOOM = auto() STABLELM = auto() @@ -300,6 +350,18 @@ class MODEL_ARCH(IntEnum): WAVTOKENIZER_DEC = auto() PLM = auto() BAILINGMOE = auto() + DOTS1 = auto() + ARCEE = auto() + + +class VISION_PROJECTOR_TYPE(IntEnum): + MLP = auto() + LDP = auto() + LDPV2 = auto() + RESAMPLER = auto() + GLM_EDGE = auto() + MERGER = auto() + GEMMA3 = auto() class MODEL_TENSOR(IntEnum): @@ -389,6 +451,8 @@ class MODEL_TENSOR(IntEnum): ATTN_Q_B = auto() ATTN_KV_A_MQA = auto() ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() ATTN_Q_A_NORM = auto() ATTN_KV_A_NORM = auto() FFN_SUB_NORM = auto() @@ -439,9 +503,68 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_K = auto() POSNET_ATTN_V = auto() POSNET_ATTN_OUT = auto() + # vision + V_MMPROJ = auto() + V_MMPROJ_FC = auto() + V_MMPROJ_MLP = auto() + V_MMPROJ_PEG = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_Q_NORM = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_K_NORM = auto() + V_ENC_ATTN_V = auto() + V_ENC_ATTN_O = auto() + V_ENC_ATTN_O_NORM = auto() + V_ENC_POST_ATTN_NORM = auto() + V_ENC_FFN_UP = auto() + V_ENC_FFN_GATE = auto() + V_ENC_FFN_DOWN = auto() + V_LAYER_SCALE_1 = auto() + V_LAYER_SCALE_2 = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() + V_MM_INP_NORM = auto() + V_MM_INP_PROJ = auto() # gemma3 + V_MM_SOFT_EMB_NORM = auto() # gemma3 + V_RESMPL_POS_EMBD_K = auto() # minicpmv + V_RESMPL_ATTN_Q = auto() # minicpmv + V_RESMPL_ATTN_K = auto() # minicpmv + V_RESMPL_ATTN_V = auto() # minicpmv + V_RESMPL_ATTN_OUT = auto() # minicpmv + V_RESMPL_KV = auto() # minicpmv + V_RESMPL_KV_NORM = auto() # minicpmv + V_RESMPL_POST_NORM = auto() # minicpmv + V_RESMPL_Q_NORM = auto() # minicpmv + V_RESMPL_PROJ = auto() # minicpmv + V_RESMPL_QUERY = auto() # minicpmv + V_TOK_EMBD_IMG_BREAK = auto() # pixtral + V_MM_PATCH_MERGER = auto() # mistral small 3.1 + # audio (mtmd) + A_ENC_EMBD_POS = auto() + A_ENC_CONV1D = auto() + A_PRE_NORM = auto() + A_POST_NORM = auto() + A_ENC_ATTN_Q = auto() + A_ENC_ATTN_K = auto() + A_ENC_ATTN_V = auto() + A_ENC_INPUT_NORM = auto() + A_ENC_OUTPUT = auto() + A_ENC_OUTPUT_NORM = auto() + A_ENC_FFN_UP = auto() + A_ENC_FFN_GATE = auto() + A_ENC_FFN_DOWN = auto() + A_MMPROJ = auto() + A_MMPROJ_FC = auto() + A_MM_NORM_PRE = auto() + A_MM_NORM_MID = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { + MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp MODEL_ARCH.LLAMA: "llama", MODEL_ARCH.LLAMA4: "llama4", MODEL_ARCH.DECI: "deci", @@ -456,6 +579,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", + MODEL_ARCH.NEO_BERT: "neo-bert", MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", @@ -508,6 +633,18 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", MODEL_ARCH.PLM: "plm", MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.DOTS1: "dots1", + MODEL_ARCH.ARCEE: "arcee", +} + +VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { + VISION_PROJECTOR_TYPE.MLP: "mlp", + VISION_PROJECTOR_TYPE.LDP: "ldp", + VISION_PROJECTOR_TYPE.LDPV2: "ldpv2", + VISION_PROJECTOR_TYPE.RESAMPLER: "resampler", + VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter", + VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger", + VISION_PROJECTOR_TYPE.GEMMA3: "gemma3", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -597,6 +734,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", @@ -647,9 +786,126 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + # vision + MODEL_TENSOR.V_MMPROJ: "mm.{bid}", + MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", + MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}", + MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm", + MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm", + MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", + MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", + MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", + MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", + MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", + MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", + MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", + MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", + MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", + MODEL_TENSOR.V_POST_NORM: "v.post_ln", + MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", + MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm", + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k", + MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q", + MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k", + MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v", + MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out", + MODEL_TENSOR.V_RESMPL_KV: "resampler.kv", + MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv", + MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post", + MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q", + MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj", + MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query", + MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral + MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1 + # audio (mtmd) + MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", + MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", + MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", + MODEL_TENSOR.A_POST_NORM: "a.post_ln", + MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q", + MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k", + MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v", + MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1", + MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out", + MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2", + MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up", + MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate", + MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down", + MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}", + MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", + MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", + MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { + MODEL_ARCH.MMPROJ: [ + MODEL_TENSOR.V_MMPROJ, + MODEL_TENSOR.V_MMPROJ_FC, + MODEL_TENSOR.V_MMPROJ_MLP, + MODEL_TENSOR.V_MMPROJ_PEG, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_Q_NORM, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_K_NORM, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_ATTN_O, + MODEL_TENSOR.V_ENC_ATTN_O_NORM, + MODEL_TENSOR.V_ENC_POST_ATTN_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_GATE, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_LAYER_SCALE_1, + MODEL_TENSOR.V_LAYER_SCALE_2, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + MODEL_TENSOR.V_MM_INP_PROJ, + MODEL_TENSOR.V_MM_INP_NORM, + MODEL_TENSOR.V_MM_SOFT_EMB_NORM, + MODEL_TENSOR.V_RESMPL_POS_EMBD_K, + MODEL_TENSOR.V_RESMPL_ATTN_Q, + MODEL_TENSOR.V_RESMPL_ATTN_K, + MODEL_TENSOR.V_RESMPL_ATTN_V, + MODEL_TENSOR.V_RESMPL_ATTN_OUT, + MODEL_TENSOR.V_RESMPL_KV, + MODEL_TENSOR.V_RESMPL_KV_NORM, + MODEL_TENSOR.V_RESMPL_POST_NORM, + MODEL_TENSOR.V_RESMPL_Q_NORM, + MODEL_TENSOR.V_RESMPL_PROJ, + MODEL_TENSOR.V_RESMPL_QUERY, + MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK, + MODEL_TENSOR.V_MM_PATCH_MERGER, + # audio + MODEL_TENSOR.A_ENC_EMBD_POS, + MODEL_TENSOR.A_ENC_CONV1D, + MODEL_TENSOR.A_PRE_NORM, + MODEL_TENSOR.A_POST_NORM, + MODEL_TENSOR.A_ENC_ATTN_Q, + MODEL_TENSOR.A_ENC_ATTN_K, + MODEL_TENSOR.A_ENC_ATTN_V, + MODEL_TENSOR.A_ENC_INPUT_NORM, + MODEL_TENSOR.A_ENC_OUTPUT, + MODEL_TENSOR.A_ENC_OUTPUT_NORM, + MODEL_TENSOR.A_ENC_FFN_UP, + MODEL_TENSOR.A_ENC_FFN_GATE, + MODEL_TENSOR.A_ENC_FFN_DOWN, + MODEL_TENSOR.A_MMPROJ, + MODEL_TENSOR.A_MMPROJ_FC, + MODEL_TENSOR.A_MM_NORM_PRE, + MODEL_TENSOR.A_MM_NORM_MID, + ], MODEL_ARCH.LLAMA: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -792,6 +1048,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.POS_EMBD, MODEL_TENSOR.OUTPUT_NORM, MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, MODEL_TENSOR.ATTN_V, @@ -816,6 +1073,34 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, ], + MODEL_ARCH.NOMIC_BERT_MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.TOKEN_TYPES, + MODEL_TENSOR.POS_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_OUT_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.LAYER_OUT_NORM, + ], + MODEL_ARCH.NEO_BERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ENC_OUTPUT_NORM, + MODEL_TENSOR.CLS, + MODEL_TENSOR.CLS_OUT, + ], MODEL_ARCH.JINA_BERT_V2: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.TOKEN_EMBD_NORM, @@ -1524,6 +1809,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_Q_B, MODEL_TENSOR.ATTN_KV_A_MQA, MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, MODEL_TENSOR.ATTN_Q_A_NORM, MODEL_TENSOR.ATTN_KV_A_NORM, MODEL_TENSOR.ATTN_OUT, @@ -1720,6 +2007,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, ], MODEL_ARCH.CHAMELEON: [ MODEL_TENSOR.TOKEN_EMBD, @@ -1778,6 +2068,45 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, ], + MODEL_ARCH.DOTS1: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_UP_SHEXP, + ], + MODEL_ARCH.ARCEE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -1860,6 +2189,8 @@ class PoolingType(IntEnum): NONE = 0 MEAN = 1 CLS = 2 + LAST = 3 + RANK = 4 class GGMLQuantizationType(IntEnum): @@ -1986,6 +2317,19 @@ class GGUFValueType(IntEnum): raise ValueError(f"Unknown type: {type(val)}") +class VisionProjectorType: + GEMMA3 = "gemma3" + IDEFICS3 = "idefics3" + PIXTRAL = "pixtral" + LLAMA4 = "llama4" + QWEN2VL = "qwen2vl_merger" + QWEN25VL = "qwen2.5vl_merger" + ULTRAVOX = "ultravox" + INTERNVL = "internvl" + QWEN2A = "qwen2a" # audio + QWEN25O = "qwen2.5o" # omni + + # Items here are (block size, type size) QK_K = 256 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index 5991cdb76..d87e8f723 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -251,7 +251,7 @@ class GGUFReader: offs += curr_size return offs - orig_offs, aparts, data_idxs, types # We can't deal with this one. - raise ValueError('Unknown/unhandled field type {gtype}') + raise ValueError(f'Unknown/unhandled field type {gtype}') def _get_tensor_info_field(self, orig_offs: int) -> ReaderField: offs = orig_offs diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 485550aad..54ca0c33f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -49,6 +49,7 @@ class TensorInfo: class GGUFValue: value: Any type: GGUFValueType + sub_type: GGUFValueType | None = None class WriterState(Enum): @@ -238,7 +239,7 @@ class GGUFWriter: for key, val in kv_data.items(): kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) - kv_bytes += self._pack_val(val.value, val.type, add_vtype=True) + kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type) fout.write(kv_bytes) @@ -268,11 +269,11 @@ class GGUFWriter: fout.flush() self.state = WriterState.TI_DATA - def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None: + def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None: if any(key in kv_data for kv_data in self.kv_data): - raise ValueError(f'Duplicated key name {key!r}') + logger.warning(f'Duplicated key name {key!r}, overwriting it with new value {val!r} of type {vtype.name}') - self.kv_data[0][key] = GGUFValue(value=val, type=vtype) + self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type) def add_uint8(self, key: str, val: int) -> None: self.add_key_value(key,val, GGUFValueType.UINT8) @@ -689,6 +690,12 @@ class GGUFWriter: def add_value_length(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_LENGTH.format(arch=self.arch), length) + def add_key_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.KEY_LENGTH_MLA.format(arch=self.arch), length) + + def add_value_length_mla(self, length: int) -> None: + self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) + def add_max_alibi_bias(self, bias: float) -> None: self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) @@ -722,6 +729,9 @@ class GGUFWriter: def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None: self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value) + def add_moe_every_n_layers(self, value: int) -> None: + self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value) + def add_swin_norm(self, value: bool) -> None: self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value) @@ -887,7 +897,7 @@ class GGUFWriter: def add_remove_extra_whitespaces(self, value: bool) -> None: self.add_bool(Keys.Tokenizer.REMOVE_EXTRA_WS, value) - def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: + def add_precompiled_charsmap(self, charsmap: bytes) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: @@ -925,13 +935,98 @@ class GGUFWriter: def add_eom_token_id(self, id: int) -> None: self.add_uint32(Keys.Tokenizer.EOM_ID, id) + def add_classifier_output_labels(self, labels: Sequence[str]) -> None: + self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels) + + # for vision models + + def add_clip_has_vision_encoder(self, value: bool) -> None: + self.add_bool(Keys.Clip.HAS_VISION_ENCODER, value) + + def add_clip_has_audio_encoder(self, value: bool) -> None: + self.add_bool(Keys.Clip.HAS_AUDIO_ENCODER, value) + + def add_clip_projector_type(self, value: str) -> None: + self.add_string(Keys.Clip.PROJECTOR_TYPE, value) + + def add_vision_projection_dim(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value) + + def add_vision_patch_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.PATCH_SIZE, value) + + def add_vision_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value) + + def add_vision_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value) + + def add_vision_block_count(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value) + + def add_vision_head_count(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value) + + def add_vision_attention_layernorm_eps(self, value: float) -> None: + self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value) + + def add_vision_image_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value) + + def add_vision_image_mean(self, values: Sequence[float]) -> None: + self.add_array(Keys.ClipVision.IMAGE_MEAN, values) + + def add_vision_image_std(self, values: Sequence[float]) -> None: + self.add_array(Keys.ClipVision.IMAGE_STD, values) + + def add_vision_spatial_merge_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value) + + def add_vision_use_gelu(self, value: bool) -> None: + self.add_bool(Keys.ClipVision.USE_GELU, value) + + def add_vision_use_silu(self, value: bool) -> None: + self.add_bool(Keys.ClipVision.USE_SILU, value) + + def add_vision_projector_scale_factor(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value) + + def add_vision_n_wa_pattern(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + + # audio models + + def add_audio_projection_dim(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.PROJECTION_DIM, value) + + def add_audio_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.EMBEDDING_LENGTH, value) + + def add_audio_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.FEED_FORWARD_LENGTH, value) + + def add_audio_block_count(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.BLOCK_COUNT, value) + + def add_audio_head_count(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.Attention.HEAD_COUNT, value) + + def add_audio_attention_layernorm_eps(self, value: float) -> None: + self.add_float32(Keys.ClipAudio.Attention.LAYERNORM_EPS, value) + + def add_audio_num_mel_bins(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.NUM_MEL_BINS, value) + + def add_audio_stack_factor(self, value: int) -> None: + self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) + def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: pack_prefix = '' if not skip_pack_prefix: pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' return struct.pack(f'{pack_prefix}{fmt}', value) - def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool) -> bytes: + def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes: kv_data = bytearray() if add_vtype: @@ -952,7 +1047,9 @@ class GGUFWriter: if len(val) == 0: raise ValueError("Invalid GGUF metadata array. Empty array") - if isinstance(val, bytes): + if sub_type is not None: + ltype = sub_type + elif isinstance(val, bytes): ltype = GGUFValueType.UINT8 else: ltype = GGUFValueType.get_type(val[0]) diff --git a/gguf-py/gguf/scripts/__init__.py b/gguf-py/gguf/scripts/__init__.py deleted file mode 100644 index e77f2e9c9..000000000 --- a/gguf-py/gguf/scripts/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# pyright: reportUnusedImport=false - -from .gguf_convert_endian import main as gguf_convert_endian_entrypoint -from .gguf_dump import main as gguf_dump_entrypoint -from .gguf_set_metadata import main as gguf_set_metadata_entrypoint -from .gguf_new_metadata import main as gguf_new_metadata_entrypoint diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py new file mode 100755 index 000000000..05f4db0f8 --- /dev/null +++ b/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -0,0 +1,1621 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import os +import sys +import numpy +import enum +from pathlib import Path +from typing import Any, Optional, Tuple, Type +import warnings + +import numpy as np +from PySide6.QtWidgets import ( + QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, + QPushButton, QLabel, QLineEdit, QFileDialog, QTableWidget, + QTableWidgetItem, QComboBox, QMessageBox, QTabWidget, + QTextEdit, QFormLayout, + QHeaderView, QDialog, QDialogButtonBox +) +from PySide6.QtCore import Qt + +# Necessary to load the local gguf package +if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): + sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +import gguf +from gguf import GGUFReader, GGUFWriter, GGUFValueType, ReaderField +from gguf.constants import TokenType, RopeScalingType, PoolingType, GGMLQuantizationType + +logger = logging.getLogger("gguf-editor-gui") + +# Map of key names to enum types for automatic enum interpretation +KEY_TO_ENUM_TYPE = { + gguf.Keys.Tokenizer.TOKEN_TYPE: TokenType, + gguf.Keys.Rope.SCALING_TYPE: RopeScalingType, + gguf.Keys.LLM.POOLING_TYPE: PoolingType, + gguf.Keys.General.FILE_TYPE: GGMLQuantizationType, +} + +# Define the tokenizer keys that should be edited together +TOKENIZER_LINKED_KEYS = [ + gguf.Keys.Tokenizer.LIST, + gguf.Keys.Tokenizer.TOKEN_TYPE, + gguf.Keys.Tokenizer.SCORES +] + + +class TokenizerEditorDialog(QDialog): + def __init__(self, tokens, token_types, scores, parent=None): + super().__init__(parent) + self.setWindowTitle("Edit Tokenizer Data") + self.resize(900, 600) + + self.tokens = tokens.copy() if tokens else [] + self.token_types = token_types.copy() if token_types else [] + self.scores = scores.copy() if scores else [] + + # Ensure all arrays have the same length + max_len = max(len(self.tokens), len(self.token_types), len(self.scores)) + if len(self.tokens) < max_len: + self.tokens.extend([""] * (max_len - len(self.tokens))) + if len(self.token_types) < max_len: + self.token_types.extend([0] * (max_len - len(self.token_types))) + if len(self.scores) < max_len: + self.scores.extend([0.0] * (max_len - len(self.scores))) + + layout = QVBoxLayout(self) + + # Add filter controls + filter_layout = QHBoxLayout() + filter_layout.addWidget(QLabel("Filter:")) + self.filter_edit = QLineEdit() + self.filter_edit.setPlaceholderText("Type to filter tokens...") + self.filter_edit.textChanged.connect(self.apply_filter) + filter_layout.addWidget(self.filter_edit) + + # Add page controls + self.page_size = 100 # Show 100 items per page + self.current_page = 0 + self.total_pages = max(1, (len(self.tokens) + self.page_size - 1) // self.page_size) + + self.page_label = QLabel(f"Page 1 of {self.total_pages}") + filter_layout.addWidget(self.page_label) + + prev_page = QPushButton("Previous") + prev_page.clicked.connect(self.previous_page) + filter_layout.addWidget(prev_page) + + next_page = QPushButton("Next") + next_page.clicked.connect(self.next_page) + filter_layout.addWidget(next_page) + + layout.addLayout(filter_layout) + + # Tokenizer data table + self.tokens_table = QTableWidget() + self.tokens_table.setColumnCount(4) + self.tokens_table.setHorizontalHeaderLabels(["Index", "Token", "Type", "Score"]) + self.tokens_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + self.tokens_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.tokens_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + self.tokens_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + + layout.addWidget(self.tokens_table) + + # Controls + controls_layout = QHBoxLayout() + + add_button = QPushButton("Add Token") + add_button.clicked.connect(self.add_token) + controls_layout.addWidget(add_button) + + remove_button = QPushButton("Remove Selected") + remove_button.clicked.connect(self.remove_selected) + controls_layout.addWidget(remove_button) + + controls_layout.addStretch() + + layout.addLayout(controls_layout) + + # Buttons + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + # Initialize the filtered values + self.filtered_indices = list(range(len(self.tokens))) + + # Load data for the first page + self.load_page() + + def apply_filter(self): + """Filter the tokens based on the search text.""" + filter_text = self.filter_edit.text().lower() + + if not filter_text: + # No filter, show all values + self.filtered_indices = list(range(len(self.tokens))) + else: + # Apply filter + self.filtered_indices = [] + for i, token in enumerate(self.tokens): + if filter_text in str(token).lower(): + self.filtered_indices.append(i) + + # Reset to first page and reload + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = 0 + self.page_label.setText(f"Page 1 of {self.total_pages}") + self.load_page() + + def previous_page(self): + """Go to the previous page of results.""" + if self.current_page > 0: + self.current_page -= 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def next_page(self): + """Go to the next page of results.""" + if self.current_page < self.total_pages - 1: + self.current_page += 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def load_page(self): + """Load the current page of tokenizer data.""" + self.tokens_table.setRowCount(0) # Clear the table + + # Calculate start and end indices for the current page + start_idx = self.current_page * self.page_size + end_idx = min(start_idx + self.page_size, len(self.filtered_indices)) + + # Pre-allocate rows for better performance + self.tokens_table.setRowCount(end_idx - start_idx) + + for row, i in enumerate(range(start_idx, end_idx)): + orig_idx = self.filtered_indices[i] + + # Index + index_item = QTableWidgetItem(str(orig_idx)) + index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index + index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tokens_table.setItem(row, 0, index_item) + + # Token + token_item = QTableWidgetItem(str(self.tokens[orig_idx])) + self.tokens_table.setItem(row, 1, token_item) + + # Token Type + token_type = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + try: + enum_val = TokenType(token_type) + display_text = f"{enum_val.name} ({token_type})" + except (ValueError, KeyError): + display_text = f"Unknown ({token_type})" + + type_item = QTableWidgetItem(display_text) + type_item.setData(Qt.ItemDataRole.UserRole, token_type) + + # Make type cell editable with a double-click handler + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tokens_table.setItem(row, 2, type_item) + + # Score + score = self.scores[orig_idx] if orig_idx < len(self.scores) else 0.0 + score_item = QTableWidgetItem(str(score)) + self.tokens_table.setItem(row, 3, score_item) + + # Connect double-click handler for token type cells + self.tokens_table.cellDoubleClicked.connect(self.handle_cell_double_click) + + def handle_cell_double_click(self, row, column): + """Handle double-click on a cell, specifically for token type editing.""" + if column == 2: # Token Type column + orig_item = self.tokens_table.item(row, 0) + if orig_item: + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + self.edit_token_type(row, orig_idx) + + def edit_token_type(self, row, orig_idx): + """Edit a token type using a dialog with a dropdown of all enum options.""" + current_value = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle("Select Token Type") + layout = QVBoxLayout(dialog) + + combo = QComboBox() + for enum_val in TokenType: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + # Set current value + try: + if isinstance(current_value, int): + enum_val = TokenType(current_value) + combo.setCurrentText(f"{enum_val.name} ({current_value})") + except (ValueError, KeyError): + pass + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Get the selected value + new_value = combo.currentData() + enum_val = TokenType(new_value) + display_text = f"{enum_val.name} ({new_value})" + + # Update the display + type_item = self.tokens_table.item(row, 2) + if type_item: + type_item.setText(display_text) + type_item.setData(Qt.ItemDataRole.UserRole, new_value) + + # Update the actual value + self.token_types[orig_idx] = new_value + + def add_token(self): + """Add a new token to the end of the list.""" + # Add to the end of the arrays + self.tokens.append("") + self.token_types.append(0) # Default to normal token + self.scores.append(0.0) + + orig_idx = len(self.tokens) - 1 + + # Add to filtered indices if it matches the current filter + filter_text = self.filter_edit.text().lower() + if not filter_text or filter_text in "": + self.filtered_indices.append(orig_idx) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + + # Go to the last page to show the new item + self.current_page = self.total_pages - 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def remove_selected(self): + """Remove selected tokens from all arrays.""" + selected_rows = [] + for item in self.tokens_table.selectedItems(): + row = item.row() + if row not in selected_rows: + selected_rows.append(row) + + if not selected_rows: + return + + # Get original indices in descending order to avoid index shifting + orig_indices = [] + for row in selected_rows: + orig_item = self.tokens_table.item(row, 0) + if orig_item: + orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole)) + orig_indices.sort(reverse=True) + + # Remove from all arrays + for idx in orig_indices: + if idx < len(self.tokens): + del self.tokens[idx] + if idx < len(self.token_types): + del self.token_types[idx] + if idx < len(self.scores): + del self.scores[idx] + + # Rebuild filtered_indices + self.filtered_indices = [] + filter_text = self.filter_edit.text().lower() + + for i, token in enumerate(self.tokens): + if not filter_text or filter_text in str(token).lower(): + self.filtered_indices.append(i) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = min(self.current_page, self.total_pages - 1) + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def get_data(self): + """Return the edited tokenizer data.""" + return self.tokens, self.token_types, self.scores + + +class ArrayEditorDialog(QDialog): + def __init__(self, array_values, element_type, key=None, parent=None): + super().__init__(parent) + self.setWindowTitle("Edit Array Values") + self.resize(700, 500) + + self.array_values = array_values + self.element_type = element_type + self.key = key + + # Get enum type for this array if applicable + self.enum_type = None + if key in KEY_TO_ENUM_TYPE and element_type == GGUFValueType.INT32: + self.enum_type = KEY_TO_ENUM_TYPE[key] + + layout = QVBoxLayout(self) + + # Add enum type information if applicable + if self.enum_type is not None: + enum_info_layout = QHBoxLayout() + enum_label = QLabel(f"Editing {self.enum_type.__name__} values:") + enum_info_layout.addWidget(enum_label) + + # Add a legend for the enum values + enum_values = ", ".join([f"{e.name}={e.value}" for e in self.enum_type]) + enum_values_label = QLabel(f"Available values: {enum_values}") + enum_values_label.setWordWrap(True) + enum_info_layout.addWidget(enum_values_label, 1) + + layout.addLayout(enum_info_layout) + + # Add search/filter controls + filter_layout = QHBoxLayout() + filter_layout.addWidget(QLabel("Filter:")) + self.filter_edit = QLineEdit() + self.filter_edit.setPlaceholderText("Type to filter values...") + self.filter_edit.textChanged.connect(self.apply_filter) + filter_layout.addWidget(self.filter_edit) + + # Add page controls for large arrays + self.page_size = 100 # Show 100 items per page + self.current_page = 0 + self.total_pages = max(1, (len(array_values) + self.page_size - 1) // self.page_size) + + self.page_label = QLabel(f"Page 1 of {self.total_pages}") + filter_layout.addWidget(self.page_label) + + prev_page = QPushButton("Previous") + prev_page.clicked.connect(self.previous_page) + filter_layout.addWidget(prev_page) + + next_page = QPushButton("Next") + next_page.clicked.connect(self.next_page) + filter_layout.addWidget(next_page) + + layout.addLayout(filter_layout) + + # Array items table + self.items_table = QTableWidget() + + # Set up columns based on whether we have an enum type + if self.enum_type is not None: + self.items_table.setColumnCount(3) + self.items_table.setHorizontalHeaderLabels(["Index", "Value", "Actions"]) + self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.items_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + else: + self.items_table.setColumnCount(2) + self.items_table.setHorizontalHeaderLabels(["Index", "Value"]) + self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + + layout.addWidget(self.items_table) + + # Controls + controls_layout = QHBoxLayout() + + add_button = QPushButton("Add Item") + add_button.clicked.connect(self.add_item) + controls_layout.addWidget(add_button) + + remove_button = QPushButton("Remove Selected") + remove_button.clicked.connect(self.remove_selected) + controls_layout.addWidget(remove_button) + + # Add bulk edit button for enum arrays + if self.enum_type is not None: + bulk_edit_button = QPushButton("Bulk Edit Selected") + bulk_edit_button.clicked.connect(self.bulk_edit_selected) + controls_layout.addWidget(bulk_edit_button) + + controls_layout.addStretch() + + layout.addLayout(controls_layout) + + # Buttons + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + # Initialize the filtered values + self.filtered_indices = list(range(len(self.array_values))) + + # Load array values for the first page + self.load_page() + + def apply_filter(self): + """Filter the array values based on the search text.""" + filter_text = self.filter_edit.text().lower() + + if not filter_text: + # No filter, show all values + self.filtered_indices = list(range(len(self.array_values))) + else: + # Apply filter + self.filtered_indices = [] + for i, value in enumerate(self.array_values): + # For enum values, search in both name and value + if self.enum_type is not None and isinstance(value, int): + try: + enum_val = self.enum_type(value) + display_text = f"{enum_val.name} ({value})".lower() + if filter_text in display_text: + self.filtered_indices.append(i) + except (ValueError, KeyError): + # If not a valid enum value, just check the raw value + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + else: + # For non-enum values, just check the string representation + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + + # Reset to first page and reload + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = 0 + self.page_label.setText(f"Page 1 of {self.total_pages}") + self.load_page() + + def previous_page(self): + """Go to the previous page of results.""" + if self.current_page > 0: + self.current_page -= 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def next_page(self): + """Go to the next page of results.""" + if self.current_page < self.total_pages - 1: + self.current_page += 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.load_page() + + def load_page(self): + """Load the current page of array values.""" + self.items_table.setRowCount(0) # Clear the table + + # Calculate start and end indices for the current page + start_idx = self.current_page * self.page_size + end_idx = min(start_idx + self.page_size, len(self.filtered_indices)) + + # Pre-allocate rows for better performance + self.items_table.setRowCount(end_idx - start_idx) + + for row, i in enumerate(range(start_idx, end_idx)): + orig_idx = self.filtered_indices[i] + value = self.array_values[orig_idx] + + # Index + index_item = QTableWidgetItem(str(orig_idx)) + index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index + index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.items_table.setItem(row, 0, index_item) + + # Value + if self.enum_type is not None: + # Display enum value and name + try: + if isinstance(value, (int, numpy.signedinteger)): + enum_val = self.enum_type(value) + display_text = f"{enum_val.name} ({value})" + else: + display_text = str(value) + except (ValueError, KeyError): + display_text = f"Unknown ({value})" + + # Store the enum value in the item + value_item = QTableWidgetItem(display_text) + value_item.setData(Qt.ItemDataRole.UserRole, value) + value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.items_table.setItem(row, 1, value_item) + + # Add an edit button in a separate column + edit_button = QPushButton("Edit") + edit_button.setProperty("row", row) + edit_button.clicked.connect(self.edit_array_enum_value) + + # Create a widget to hold the button + button_widget = QWidget() + button_layout = QHBoxLayout(button_widget) + button_layout.setContentsMargins(2, 2, 2, 2) + button_layout.addWidget(edit_button) + button_layout.addStretch() + + self.items_table.setCellWidget(row, 2, button_widget) + else: + value_item = QTableWidgetItem(str(value)) + self.items_table.setItem(row, 1, value_item) + + def edit_array_enum_value(self): + """Handle editing an enum value in the array editor.""" + button = self.sender() + row = button.property("row") + + # Get the original index from the table item + orig_item = self.items_table.item(row, 0) + new_item = self.items_table.item(row, 1) + if orig_item and new_item and self.enum_type and self.edit_enum_value(row, self.enum_type): + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + new_value = new_item.data(Qt.ItemDataRole.UserRole) + # Update the stored value in the array + if isinstance(new_value, (int, float, str, bool)): + self.array_values[orig_idx] = new_value + + def bulk_edit_selected(self): + """Edit multiple enum values at once.""" + if not self.enum_type: + return + + selected_rows = set() + for item in self.items_table.selectedItems(): + selected_rows.add(item.row()) + + if not selected_rows: + QMessageBox.information(self, "No Selection", "Please select at least one row to edit.") + return + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle(f"Bulk Edit {self.enum_type.__name__} Values") + layout = QVBoxLayout(dialog) + + layout.addWidget(QLabel(f"Set {len(selected_rows)} selected items to:")) + + combo = QComboBox() + for enum_val in self.enum_type: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Get the selected value + new_value = combo.currentData() + enum_val = self.enum_type(new_value) + display_text = f"{enum_val.name} ({new_value})" + + # Update all selected rows + for row in selected_rows: + orig_item = self.items_table.item(row, 0) + new_item = self.items_table.item(row, 1) + if orig_item and new_item: + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + self.array_values[orig_idx] = new_value + + # Update the display + new_item.setText(display_text) + new_item.setData(Qt.ItemDataRole.UserRole, new_value) + + def add_item(self): + # Add to the end of the array + orig_idx = len(self.array_values) + + # Add default value based on type + if self.enum_type is not None: + # Default to first enum value + default_value = list(self.enum_type)[0].value + self.array_values.append(default_value) + else: + if self.element_type == GGUFValueType.STRING: + self.array_values.append("") + else: + self.array_values.append(0) + + # Add to filtered indices if it matches the current filter + self.filtered_indices.append(orig_idx) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + + # Go to the last page to show the new item + self.current_page = self.total_pages - 1 + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def remove_selected(self): + selected_rows = [] + for item in self.items_table.selectedItems(): + row = item.row() + if row not in selected_rows: + selected_rows.append(row) + + if not selected_rows: + return + + # Get original indices in descending order to avoid index shifting + orig_indices = list() + for row in selected_rows: + orig_item = self.items_table.item(row, 0) + if orig_item: + orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole)) + orig_indices.sort(reverse=True) + + # Remove from array_values + for idx in orig_indices: + del self.array_values[idx] + + # Rebuild filtered_indices + self.filtered_indices = [] + filter_text = self.filter_edit.text().lower() + + for i, value in enumerate(self.array_values): + if not filter_text: + self.filtered_indices.append(i) + else: + # Apply filter + if self.enum_type is not None and isinstance(value, int): + try: + enum_val = self.enum_type(value) + display_text = f"{enum_val.name} ({value})".lower() + if filter_text in display_text: + self.filtered_indices.append(i) + except (ValueError, KeyError): + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + else: + if filter_text in str(value).lower(): + self.filtered_indices.append(i) + + # Update pagination + self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.current_page = min(self.current_page, self.total_pages - 1) + self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + + # Reload the page + self.load_page() + + def edit_enum_value(self, row: int, enum_type: Type[enum.Enum]): + """Edit an enum value using a dialog with a dropdown of all enum options.""" + # Get the original index from the table item + orig_item = self.items_table.item(row, 0) + if orig_item: + orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) + else: + return + current_value = self.array_values[orig_idx] + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle(f"Select {enum_type.__name__} Value") + layout = QVBoxLayout(dialog) + + # Add description + description = QLabel(f"Select a {enum_type.__name__} value:") + layout.addWidget(description) + + # Use a combo box for quick selection + combo = QComboBox() + for enum_val in enum_type: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + # Set current value + try: + if isinstance(current_value, int): + enum_val = enum_type(current_value) + combo.setCurrentText(f"{enum_val.name} ({current_value})") + except (ValueError, KeyError): + pass + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Update the value display and stored data + new_value = combo.currentData() + enum_val = enum_type(new_value) + display_text = f"{enum_val.name} ({new_value})" + + new_item = self.items_table.item(row, 1) + if new_item: + new_item.setText(display_text) + new_item.setData(Qt.ItemDataRole.UserRole, new_value) + + # Update the actual array value + self.array_values[orig_idx] = new_value + return True + return False + + def get_array_values(self): + # The array_values list is kept up-to-date as edits are made + return self.array_values + + +class AddMetadataDialog(QDialog): + def __init__(self, parent=None): + super().__init__(parent) + self.setWindowTitle("Add Metadata") + self.resize(400, 200) + + layout = QVBoxLayout(self) + + form_layout = QFormLayout() + + self.key_edit = QLineEdit() + form_layout.addRow("Key:", self.key_edit) + + self.type_combo = QComboBox() + for value_type in GGUFValueType: + if value_type != GGUFValueType.ARRAY: # Skip array type for simplicity + self.type_combo.addItem(value_type.name, value_type) + form_layout.addRow("Type:", self.type_combo) + + self.value_edit = QTextEdit() + form_layout.addRow("Value:", self.value_edit) + + layout.addLayout(form_layout) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(self.accept) + buttons.rejected.connect(self.reject) + layout.addWidget(buttons) + + def get_data(self) -> Tuple[str, GGUFValueType, Any]: + key = self.key_edit.text() + value_type = self.type_combo.currentData() + value_text = self.value_edit.toPlainText() + + # Convert value based on type + if value_type == GGUFValueType.UINT8: + value = np.uint8(int(value_text)) + elif value_type == GGUFValueType.INT8: + value = np.int8(int(value_text)) + elif value_type == GGUFValueType.UINT16: + value = np.uint16(int(value_text)) + elif value_type == GGUFValueType.INT16: + value = np.int16(int(value_text)) + elif value_type == GGUFValueType.UINT32: + value = np.uint32(int(value_text)) + elif value_type == GGUFValueType.INT32: + value = np.int32(int(value_text)) + elif value_type == GGUFValueType.FLOAT32: + value = np.float32(float(value_text)) + elif value_type == GGUFValueType.BOOL: + value = value_text.lower() in ('true', 'yes', '1') + elif value_type == GGUFValueType.STRING: + value = value_text + else: + value = value_text + + return key, value_type, value + + +class GGUFEditorWindow(QMainWindow): + def __init__(self): + super().__init__() + + self.setWindowTitle("GGUF Editor") + self.resize(1000, 800) + + self.current_file = None + self.reader = None + self.modified = False + self.metadata_changes = {} # Store changes to apply when saving + self.metadata_to_remove = set() # Store keys to remove when saving + self.on_metadata_changed_is_connected = False + + self.setup_ui() + + def setup_ui(self): + central_widget = QWidget() + self.setCentralWidget(central_widget) + + main_layout = QVBoxLayout(central_widget) + + # File controls + file_layout = QHBoxLayout() + + self.file_path_edit = QLineEdit() + self.file_path_edit.setReadOnly(True) + file_layout.addWidget(self.file_path_edit) + + open_button = QPushButton("Open GGUF") + open_button.clicked.connect(self.open_file) + file_layout.addWidget(open_button) + + save_button = QPushButton("Save As...") + save_button.clicked.connect(self.save_file) + file_layout.addWidget(save_button) + + main_layout.addLayout(file_layout) + + # Tabs for different views + self.tabs = QTabWidget() + + # Metadata tab + self.metadata_tab = QWidget() + metadata_layout = QVBoxLayout(self.metadata_tab) + + # Metadata table + self.metadata_table = QTableWidget() + self.metadata_table.setColumnCount(4) + self.metadata_table.setHorizontalHeaderLabels(["Key", "Type", "Value", "Actions"]) + self.metadata_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + self.metadata_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) + self.metadata_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.Stretch) + self.metadata_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + metadata_layout.addWidget(self.metadata_table) + + # Metadata controls + metadata_controls = QHBoxLayout() + + add_metadata_button = QPushButton("Add Metadata") + add_metadata_button.clicked.connect(self.add_metadata) + metadata_controls.addWidget(add_metadata_button) + + metadata_controls.addStretch() + + metadata_layout.addLayout(metadata_controls) + + # Tensors tab + self.tensors_tab = QWidget() + tensors_layout = QVBoxLayout(self.tensors_tab) + + self.tensors_table = QTableWidget() + self.tensors_table.setColumnCount(5) + self.tensors_table.setHorizontalHeaderLabels(["Name", "Type", "Shape", "Elements", "Size (bytes)"]) + self.tensors_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) + self.tensors_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) + self.tensors_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + self.tensors_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + self.tensors_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeMode.ResizeToContents) + tensors_layout.addWidget(self.tensors_table) + + # Add tabs to tab widget + self.tabs.addTab(self.metadata_tab, "Metadata") + self.tabs.addTab(self.tensors_tab, "Tensors") + + main_layout.addWidget(self.tabs) + + # Status bar + self.statusBar().showMessage("Ready") + + def load_file(self, file_path): + """Load a GGUF file by path""" + try: + self.statusBar().showMessage(f"Loading {file_path}...") + QApplication.processEvents() + + self.reader = GGUFReader(file_path, 'r') + self.current_file = file_path + self.file_path_edit.setText(file_path) + + self.load_metadata() + self.load_tensors() + + self.metadata_changes = {} + self.metadata_to_remove = set() + self.modified = False + + self.statusBar().showMessage(f"Loaded {file_path}") + return True + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to open file: {str(e)}") + self.statusBar().showMessage("Error loading file") + return False + + def open_file(self): + file_path, _ = QFileDialog.getOpenFileName( + self, "Open GGUF File", "", "GGUF Files (*.gguf);;All Files (*)" + ) + + if not file_path: + return + + self.load_file(file_path) + + def load_metadata(self): + self.metadata_table.setRowCount(0) + + if not self.reader: + return + + # Disconnect to prevent triggering during loading + if self.on_metadata_changed_is_connected: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore') + self.metadata_table.itemChanged.disconnect(self.on_metadata_changed) + self.on_metadata_changed_is_connected = False + + for i, (key, field) in enumerate(self.reader.fields.items()): + self.metadata_table.insertRow(i) + + # Key + key_item = QTableWidgetItem(key) + key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(i, 0, key_item) + + # Type + if not field.types: + type_str = "N/A" + elif field.types[0] == GGUFValueType.ARRAY: + nest_count = len(field.types) - 1 + element_type = field.types[-1].name + # Check if this is an enum array + enum_type = self.get_enum_for_key(key) + if enum_type is not None and field.types[-1] == GGUFValueType.INT32: + element_type = enum_type.__name__ + type_str = '[' * nest_count + element_type + ']' * nest_count + else: + type_str = str(field.types[0].name) + # Check if this is an enum field + enum_type = self.get_enum_for_key(key) + if enum_type is not None and field.types[0] == GGUFValueType.INT32: + type_str = enum_type.__name__ + + type_item = QTableWidgetItem(type_str) + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(i, 1, type_item) + + # Value + value_str = self.format_field_value(field) + value_item = QTableWidgetItem(value_str) + + # Make only simple values editable + if len(field.types) == 1 and field.types[0] != GGUFValueType.ARRAY: + value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable) + else: + value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + + self.metadata_table.setItem(i, 2, value_item) + + # Actions + actions_widget = QWidget() + actions_layout = QHBoxLayout(actions_widget) + actions_layout.setContentsMargins(2, 2, 2, 2) + + # Add Edit button for arrays and enum fields + if field.types and field.types[0] == GGUFValueType.ARRAY: + edit_button = QPushButton("Edit") + edit_button.setProperty("row", i) + edit_button.setProperty("key", key) + edit_button.clicked.connect(self.edit_array_metadata) + actions_layout.addWidget(edit_button) + + # Add special label for tokenizer linked fields + if key in TOKENIZER_LINKED_KEYS: + edit_button.setText("Edit Tokenizer") + edit_button.setToolTip("Edit all tokenizer data together") + elif len(field.types) == 1 and self.get_enum_for_key(key) is not None: + edit_button = QPushButton("Edit") + edit_button.setProperty("row", i) + edit_button.setProperty("key", key) + edit_button.clicked.connect(self.edit_metadata_enum) + actions_layout.addWidget(edit_button) + + remove_button = QPushButton("Remove") + remove_button.setProperty("row", i) + remove_button.setProperty("key", key) + remove_button.clicked.connect(self.remove_metadata) + actions_layout.addWidget(remove_button) + + self.metadata_table.setCellWidget(i, 3, actions_widget) + + # Reconnect after loading + self.metadata_table.itemChanged.connect(self.on_metadata_changed) + self.on_metadata_changed_is_connected = True + + def extract_array_values(self, field: ReaderField) -> list: + """Extract all values from an array field.""" + if not field.types or field.types[0] != GGUFValueType.ARRAY: + return [] + + curr_type = field.types[1] + array_values = [] + total_elements = len(field.data) + + if curr_type == GGUFValueType.STRING: + for element_pos in range(total_elements): + value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8') + array_values.append(value_string) + elif self.reader and curr_type in self.reader.gguf_scalar_to_np: + for element_pos in range(total_elements): + array_values.append(field.parts[-1 - (total_elements - element_pos - 1)][0]) + + return array_values + + def get_enum_for_key(self, key: str) -> Optional[Type[enum.Enum]]: + """Get the enum type for a given key if it exists.""" + return KEY_TO_ENUM_TYPE.get(key) + + def format_enum_value(self, value: Any, enum_type: Type[enum.Enum]) -> str: + """Format a value as an enum if possible.""" + try: + if isinstance(value, (int, str)): + enum_value = enum_type(value) + return f"{enum_value.name} ({value})" + except (ValueError, KeyError): + pass + return str(value) + + def format_field_value(self, field: ReaderField) -> str: + if not field.types: + return "N/A" + + if len(field.types) == 1: + curr_type = field.types[0] + if curr_type == GGUFValueType.STRING: + return str(bytes(field.parts[-1]), encoding='utf-8') + elif self.reader and curr_type in self.reader.gguf_scalar_to_np: + value = field.parts[-1][0] + # Check if this field has an enum type + enum_type = self.get_enum_for_key(field.name) + if enum_type is not None: + return self.format_enum_value(value, enum_type) + return str(value) + + if field.types[0] == GGUFValueType.ARRAY: + array_values = self.extract_array_values(field) + render_element = min(5, len(array_values)) + + # Get enum type for this array if applicable + enum_type = self.get_enum_for_key(field.name) + + if enum_type is not None: + array_elements = [] + for i in range(render_element): + array_elements.append(self.format_enum_value(array_values[i], enum_type)) + else: + array_elements = [str(array_values[i]) for i in range(render_element)] + + return f"[ {', '.join(array_elements).strip()}{', ...' if len(array_values) > len(array_elements) else ''} ]" + + return "Complex value" + + def load_tensors(self): + self.tensors_table.setRowCount(0) + + if not self.reader: + return + + for i, tensor in enumerate(self.reader.tensors): + self.tensors_table.insertRow(i) + + # Name + name_item = QTableWidgetItem(tensor.name) + name_item.setFlags(name_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 0, name_item) + + # Type + type_item = QTableWidgetItem(tensor.tensor_type.name) + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 1, type_item) + + # Shape + shape_str = " × ".join(str(d) for d in tensor.shape) + shape_item = QTableWidgetItem(shape_str) + shape_item.setFlags(shape_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 2, shape_item) + + # Elements + elements_item = QTableWidgetItem(str(tensor.n_elements)) + elements_item.setFlags(elements_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 3, elements_item) + + # Size + size_item = QTableWidgetItem(f"{tensor.n_bytes:,}") + size_item.setFlags(size_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.tensors_table.setItem(i, 4, size_item) + + def on_metadata_changed(self, item): + if item.column() != 2: # Only handle value column changes + return + + row = item.row() + orig_item = self.metadata_table.item(row, 0) + key = None + if orig_item: + key = orig_item.text() + new_value = item.text() + + field = None + if self.reader and key: + field = self.reader.get_field(key) + if not field or not field.types or not key: + return + + value_type = field.types[0] + + # Check if this is an enum field + enum_type = self.get_enum_for_key(key) + if enum_type is not None and value_type == GGUFValueType.INT32: + # Try to parse the enum value from the text + try: + # Check if it's a name + try: + enum_val = enum_type[new_value] + converted_value = enum_val.value + except (KeyError, AttributeError): + # Check if it's a number or "NAME (value)" format + if '(' in new_value and ')' in new_value: + # Extract the value from "NAME (value)" format + value_part = new_value.split('(')[1].split(')')[0].strip() + converted_value = int(value_part) + else: + # Try to convert directly to int + converted_value = int(new_value) + + # Validate that it's a valid enum value + enum_type(converted_value) + + # Store the change + self.metadata_changes[key] = (value_type, converted_value) + self.modified = True + + # Update display with formatted enum value + formatted_value = self.format_enum_value(converted_value, enum_type) + item.setText(formatted_value) + + self.statusBar().showMessage(f"Changed {key} to {formatted_value}") + return + except (ValueError, KeyError) as e: + QMessageBox.warning( + self, + f"Invalid Enum Value ({e})", + f"'{new_value}' is not a valid {enum_type.__name__} value.\n" + f"Valid values are: {', '.join(v.name for v in enum_type)}") + + # Revert to original value + original_value = self.format_field_value(field) + item.setText(original_value) + return + + try: + # Convert the string value to the appropriate type + if value_type == GGUFValueType.UINT8: + converted_value = np.uint8(int(new_value)) + elif value_type == GGUFValueType.INT8: + converted_value = np.int8(int(new_value)) + elif value_type == GGUFValueType.UINT16: + converted_value = np.uint16(int(new_value)) + elif value_type == GGUFValueType.INT16: + converted_value = np.int16(int(new_value)) + elif value_type == GGUFValueType.UINT32: + converted_value = np.uint32(int(new_value)) + elif value_type == GGUFValueType.INT32: + converted_value = np.int32(int(new_value)) + elif value_type == GGUFValueType.FLOAT32: + converted_value = np.float32(float(new_value)) + elif value_type == GGUFValueType.BOOL: + converted_value = new_value.lower() in ('true', 'yes', '1') + elif value_type == GGUFValueType.STRING: + converted_value = new_value + else: + # Unsupported type for editing + return + + # Store the change + self.metadata_changes[key] = (value_type, converted_value) + self.modified = True + + self.statusBar().showMessage(f"Changed {key} to {new_value}") + except ValueError: + QMessageBox.warning(self, "Invalid Value", f"The value '{new_value}' is not valid for type {value_type.name}") + + # Revert to original value + original_value = self.format_field_value(field) + item.setText(original_value) + + def remove_metadata(self): + button = self.sender() + key = button.property("key") + row = button.property("row") + + reply = QMessageBox.question( + self, "Confirm Removal", + f"Are you sure you want to remove the metadata key '{key}'?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No + ) + + if reply == QMessageBox.StandardButton.Yes: + self.metadata_table.removeRow(row) + self.metadata_to_remove.add(key) + + # If we previously had changes for this key, remove them + if key in self.metadata_changes: + del self.metadata_changes[key] + + self.modified = True + self.statusBar().showMessage(f"Marked {key} for removal") + + def edit_metadata_enum(self): + """Edit an enum metadata field.""" + button = self.sender() + key = button.property("key") + row = button.property("row") + + field = None + if self.reader: + field = self.reader.get_field(key) + if not field or not field.types: + return + + enum_type = self.get_enum_for_key(key) + if enum_type is None: + return + + # Get current value + current_value = field.contents() + + # Create a dialog with enum options + dialog = QDialog(self) + dialog.setWindowTitle(f"Select {enum_type.__name__} Value") + layout = QVBoxLayout(dialog) + + combo = QComboBox() + for enum_val in enum_type: + combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value) + + # Set current value + try: + if isinstance(current_value, (int, str)): + enum_val = enum_type(current_value) + combo.setCurrentText(f"{enum_val.name} ({current_value})") + except (ValueError, KeyError): + pass + + layout.addWidget(combo) + + buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons.accepted.connect(dialog.accept) + buttons.rejected.connect(dialog.reject) + layout.addWidget(buttons) + + if dialog.exec() == QDialog.DialogCode.Accepted: + # Get the selected value + new_value = combo.currentData() + enum_val = enum_type(new_value) + + # Store the change + self.metadata_changes[key] = (field.types[0], new_value) + self.modified = True + + # Update display + display_text = f"{enum_val.name} ({new_value})" + target_item = self.metadata_table.item(row, 2) + if target_item: + target_item.setText(display_text) + + self.statusBar().showMessage(f"Changed {key} to {display_text}") + + def edit_array_metadata(self): + button = self.sender() + key = button.property("key") + row = button.property("row") + + # Check if this is one of the linked tokenizer keys + if key in TOKENIZER_LINKED_KEYS: + self.edit_tokenizer_metadata(key) + return + + field = None + if self.reader: + field = self.reader.get_field(key) + if not field or not field.types or field.types[0] != GGUFValueType.ARRAY: + return + + # Get array element type + element_type = field.types[1] + + # Extract array values + array_values = self.extract_array_values(field) + + # Open array editor dialog + dialog = ArrayEditorDialog(array_values, element_type, key, self) + if dialog.exec() == QDialog.DialogCode.Accepted: + new_values = dialog.get_array_values() + + # Store the change + self.metadata_changes[key] = (GGUFValueType.ARRAY, (element_type, new_values)) + self.modified = True + + # Update display + enum_type = self.get_enum_for_key(key) + if enum_type is not None and element_type == GGUFValueType.INT32: + value_str = f"[ {', '.join(self.format_enum_value(v, enum_type) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]" + else: + value_str = f"[ {', '.join(str(v) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]" + target_item = self.metadata_table.item(row, 2) + if target_item: + target_item.setText(value_str) + + self.statusBar().showMessage(f"Updated array values for {key}") + + def edit_tokenizer_metadata(self, trigger_key): + """Edit the linked tokenizer metadata arrays together.""" + if not self.reader: + return + + # Get all three fields + tokens_field = self.reader.get_field(gguf.Keys.Tokenizer.LIST) + token_types_field = self.reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) + scores_field = self.reader.get_field(gguf.Keys.Tokenizer.SCORES) + + # Extract values from each field + tokens = self.extract_array_values(tokens_field) if tokens_field else [] + token_types = self.extract_array_values(token_types_field) if token_types_field else [] + scores = self.extract_array_values(scores_field) if scores_field else [] + + # Apply any pending changes + if gguf.Keys.Tokenizer.LIST in self.metadata_changes: + _, (_, tokens) = self.metadata_changes[gguf.Keys.Tokenizer.LIST] + if gguf.Keys.Tokenizer.TOKEN_TYPE in self.metadata_changes: + _, (_, token_types) = self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE] + if gguf.Keys.Tokenizer.SCORES in self.metadata_changes: + _, (_, scores) = self.metadata_changes[gguf.Keys.Tokenizer.SCORES] + + # Open the tokenizer editor dialog + dialog = TokenizerEditorDialog(tokens, token_types, scores, self) + if dialog.exec() == QDialog.DialogCode.Accepted: + new_tokens, new_token_types, new_scores = dialog.get_data() + + # Store changes for all three arrays + if tokens_field: + self.metadata_changes[gguf.Keys.Tokenizer.LIST] = ( + GGUFValueType.ARRAY, + (tokens_field.types[1], new_tokens) + ) + + if token_types_field: + self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE] = ( + GGUFValueType.ARRAY, + (token_types_field.types[1], new_token_types) + ) + + if scores_field: + self.metadata_changes[gguf.Keys.Tokenizer.SCORES] = ( + GGUFValueType.ARRAY, + (scores_field.types[1], new_scores) + ) + + self.modified = True + + # Update display for all three fields + self.update_tokenizer_display(gguf.Keys.Tokenizer.LIST, new_tokens) + self.update_tokenizer_display(gguf.Keys.Tokenizer.TOKEN_TYPE, new_token_types) + self.update_tokenizer_display(gguf.Keys.Tokenizer.SCORES, new_scores) + + self.statusBar().showMessage("Updated tokenizer data") + + def update_tokenizer_display(self, key, values): + """Update the display of a tokenizer field in the metadata table.""" + for row in range(self.metadata_table.rowCount()): + key_item = self.metadata_table.item(row, 0) + if key_item and key_item.text() == key: + value_str = f"[ {', '.join(str(v) for v in values[:5])}{', ...' if len(values) > 5 else ''} ]" + value_item = self.metadata_table.item(row, 2) + if value_item: + value_item.setText(value_str) + break + + def add_metadata(self): + dialog = AddMetadataDialog(self) + if dialog.exec() == QDialog.DialogCode.Accepted: + key, value_type, value = dialog.get_data() + + if not key: + QMessageBox.warning(self, "Invalid Key", "Key cannot be empty") + return + + # Check if key already exists + for row in range(self.metadata_table.rowCount()): + orig_item = self.metadata_table.item(row, 0) + if orig_item and orig_item.text() == key: + QMessageBox.warning(self, "Duplicate Key", f"Key '{key}' already exists") + return + + # Add to table + row = self.metadata_table.rowCount() + self.metadata_table.insertRow(row) + + # Key + key_item = QTableWidgetItem(key) + key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(row, 0, key_item) + + # Type + type_item = QTableWidgetItem(value_type.name) + type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(row, 1, type_item) + + # Value + value_item = QTableWidgetItem(str(value)) + value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable) + self.metadata_table.setItem(row, 2, value_item) + + # Actions + actions_widget = QWidget() + actions_layout = QHBoxLayout(actions_widget) + actions_layout.setContentsMargins(2, 2, 2, 2) + + remove_button = QPushButton("Remove") + remove_button.setProperty("row", row) + remove_button.setProperty("key", key) + remove_button.clicked.connect(self.remove_metadata) + actions_layout.addWidget(remove_button) + + self.metadata_table.setCellWidget(row, 3, actions_widget) + + # Store the change + self.metadata_changes[key] = (value_type, value) + self.modified = True + + self.statusBar().showMessage(f"Added new metadata key {key}") + + def save_file(self): + if not self.reader: + QMessageBox.warning(self, "No File Open", "Please open a GGUF file first") + return + + if not self.modified and not self.metadata_changes and not self.metadata_to_remove: + QMessageBox.information(self, "No Changes", "No changes to save") + return + + file_path, _ = QFileDialog.getSaveFileName( + self, "Save GGUF File As", "", "GGUF Files (*.gguf);;All Files (*)" + ) + + if not file_path: + return + + try: + self.statusBar().showMessage(f"Saving to {file_path}...") + QApplication.processEvents() + + # Get architecture and endianness from the original file + arch = 'unknown' + field = self.reader.get_field(gguf.Keys.General.ARCHITECTURE) + if field: + arch = field.contents() + + # Create writer + writer = GGUFWriter(file_path, arch=arch, endianess=self.reader.endianess) + + # Get alignment if present + alignment = None + field = self.reader.get_field(gguf.Keys.General.ALIGNMENT) + if field: + alignment = field.contents() + if alignment is not None: + writer.data_alignment = alignment + + # Copy metadata with changes + for field in self.reader.fields.values(): + # Skip virtual fields and fields written by GGUFWriter + if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'): + continue + + # Skip fields marked for removal + if field.name in self.metadata_to_remove: + continue + + # Apply changes if any + sub_type = None + if field.name in self.metadata_changes: + value_type, value = self.metadata_changes[field.name] + if value_type == GGUFValueType.ARRAY: + # Handle array values + sub_type, value = value + else: + # Copy original value + value = field.contents() + value_type = field.types[0] + if value_type == GGUFValueType.ARRAY: + sub_type = field.types[-1] + + if value is not None: + writer.add_key_value(field.name, value, value_type, sub_type=sub_type) + + # Add new metadata + for key, (value_type, value) in self.metadata_changes.items(): + # Skip if the key already existed (we handled it above) + if self.reader.get_field(key) is not None: + continue + + sub_type = None + if value_type == GGUFValueType.ARRAY: + # Handle array values + sub_type, value = value + + writer.add_key_value(key, value, value_type, sub_type=sub_type) + + # Add tensors (including data) + for tensor in self.reader.tensors: + writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) + + # Write header and metadata + writer.open_output_file(Path(file_path)) + writer.write_header_to_file() + writer.write_kv_data_to_file() + + # Write tensor data using the optimized method + writer.write_tensors_to_file(progress=False) + + writer.close() + + self.statusBar().showMessage(f"Saved to {file_path}") + + # Ask if user wants to open the new file + reply = QMessageBox.question( + self, "Open Saved File", + "Would you like to open the newly saved file?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.Yes + ) + + if reply == QMessageBox.StandardButton.Yes: + self.reader = GGUFReader(file_path, 'r') + self.current_file = file_path + self.file_path_edit.setText(file_path) + + self.load_metadata() + self.load_tensors() + + self.metadata_changes = {} + self.metadata_to_remove = set() + self.modified = False + + except Exception as e: + QMessageBox.critical(self, "Error", f"Failed to save file: {str(e)}") + self.statusBar().showMessage("Error saving file") + + +def main() -> None: + parser = argparse.ArgumentParser(description="GUI GGUF Editor") + parser.add_argument("model_path", nargs="?", help="path to GGUF model file to load at startup") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + + args = parser.parse_args() + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + app = QApplication(sys.argv) + window = GGUFEditorWindow() + window.show() + + # Load model if specified + if args.model_path: + if os.path.isfile(args.model_path) and args.model_path.endswith('.gguf'): + window.load_file(args.model_path) + else: + logger.error(f"Invalid model path: {args.model_path}") + QMessageBox.warning( + window, + "Invalid Model Path", + f"The specified file does not exist or is not a GGUF file: {args.model_path}") + + sys.exit(app.exec()) + + +if __name__ == '__main__': + main() diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 7aff6c925..63f230034 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -24,6 +24,7 @@ class MetadataDetails(NamedTuple): type: gguf.GGUFValueType value: Any description: str = '' + sub_type: gguf.GGUFValueType | None = None def get_field_data(reader: gguf.GGUFReader, key: str) -> Any: @@ -57,7 +58,9 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Removing {field.name}') continue - old_val = MetadataDetails(field.types[0], field.contents()) + val_type = field.types[0] + sub_type = field.types[-1] if val_type == gguf.GGUFValueType.ARRAY else None + old_val = MetadataDetails(val_type, field.contents(), sub_type=sub_type) val = new_metadata.get(field.name, old_val) if field.name in new_metadata: @@ -67,7 +70,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new logger.debug(f'Copying {field.name}') if val.value is not None: - writer.add_key_value(field.name, val.value, val.type) + writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type) if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: logger.debug('Adding chat template(s)') diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 35154e9b5..79f044d2a 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -31,6 +31,7 @@ class TensorNameMap: "model.embeddings", # rwkv7 "model.word_embeddings", # bailingmoe "language_model.model.embed_tokens", # llama4 + "encoder", # neobert ), # Token type embeddings @@ -68,7 +69,7 @@ class TensorNameMap: "output_layer", # chatglm "head", # rwkv "head.out", # wavtokenizer - "language_model.lm_head", # llama4 + "lm_head", # llama4 ), # Output norm @@ -91,7 +92,7 @@ class TensorNameMap: "rwkv.ln_out", # rwkv6 "model.ln_out", # rwkv7 "backbone.final_layer_norm", # wavtokenizer - "language_model.model.norm", # llama4 + "model.norm", # llama4 ), # Rope frequencies @@ -133,7 +134,8 @@ class TensorNameMap: "transformer.layers.{bid}.attn_norm", # openelm "rwkv.blocks.{bid}.ln1", # rwkv6 "model.layers.{bid}.ln1", # rwkv7 - "language_model.model.layers.{bid}.input_layernorm", # llama4 + "model.layers.{bid}.input_layernorm", # llama4 + "transformer_encoder.{bid}.attention_norm", # neobert ), # Attention norm 2 @@ -157,9 +159,11 @@ class TensorNameMap: "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert + "encoder.layers.{bid}.mixer.Wqkv", # jina "model.layers.{bid}.self_attn.qkv_proj", # phi3 "encoder.layers.{bid}.self_attention.query_key_value", # chatglm "transformer.layers.{bid}.attn.qkv_proj", # openelm + "transformer_encoder.{bid}.qkv", # neobert ), # Attention query @@ -168,12 +172,13 @@ class TensorNameMap: "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert + "transformer.layer.{bid}.attention.q_lin", # distillbert "transformer.h.{bid}.attn.q_proj", # gpt-j "model.layers.layers.{bid}.self_attn.q_proj", # plamo "model.layers.{bid}.attention.wq", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok "transformer.h.{bid}.attn.attention.q_proj", # exaone - "language_model.model.layers.{bid}.self_attn.q_proj", # llama4 + "model.layers.{bid}.self_attn.q_proj", # llama4 ), # Attention key @@ -182,13 +187,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert + "transformer.layer.{bid}.attention.k_lin", # distillbert "transformer.h.{bid}.attn.k_proj", # gpt-j "transformer.h.{bid}.attn.k", # refact "model.layers.layers.{bid}.self_attn.k_proj", # plamo "model.layers.{bid}.attention.wk", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok "transformer.h.{bid}.attn.attention.k_proj", # exaone - "language_model.model.layers.{bid}.self_attn.k_proj", # llama4 + "model.layers.{bid}.self_attn.k_proj", # llama4 ), # Attention value @@ -196,13 +202,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert + "transformer.layer.{bid}.attention.v_lin", # distillbert "transformer.h.{bid}.attn.v_proj", # gpt-j "transformer.h.{bid}.attn.v", # refact "model.layers.layers.{bid}.self_attn.v_proj", # plamo "model.layers.{bid}.attention.wv", # internlm2 "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok "transformer.h.{bid}.attn.attention.v_proj", # exaone - "language_model.model.layers.{bid}.self_attn.v_proj", # llama4 + "model.layers.{bid}.self_attn.v_proj", # llama4 ), # Attention output @@ -216,6 +223,7 @@ class TensorNameMap: "model.layers.{bid}.self_attn.linear_attn", # deci "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.layer.{bid}.attention.out_lin", # distillbert "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon "model.layers.{bid}.self_attn.dense", # persimmon @@ -224,17 +232,20 @@ class TensorNameMap: "model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.{bid}.attention.wo", # internlm2 "encoder.layers.{bid}.attn.out_proj", # nomic-bert + "encoder.layers.{bid}.mixer.out_proj", # jina "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx "encoder.layers.{bid}.self_attention.dense", # chatglm "transformer.layers.{bid}.attn.out_proj", # openelm "transformer.h.{bid}.attn.attention.out_proj", # exaone - "language_model.model.layers.{bid}.self_attn.o_proj", # llama4 + "model.layers.{bid}.self_attn.o_proj", # llama4 + "transformer_encoder.{bid}.wo", # neobert ), # Attention output norm MODEL_TENSOR.ATTN_OUT_NORM: ( "encoder.layer.{bid}.attention.output.LayerNorm", # bert + "transformer.layer.{bid}.sa_layer_norm", # distillbert "encoder.layers.{bid}.norm1", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_1", # Grok "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx @@ -268,7 +279,8 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm - "language_model.model.layers.{bid}.post_attention_layernorm", # llama4 + "model.layers.{bid}.post_attention_layernorm", # llama4 + "transformer_encoder.{bid}.ffn_norm", # neobert ), # Post feed-forward norm @@ -289,7 +301,8 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe - "language_model.model.layers.{bid}.feed_forward.router", # llama4 + "model.layers.{bid}.feed_forward.router", # llama4 + "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -297,7 +310,7 @@ class TensorNameMap: ), MODEL_TENSOR.FFN_EXP_PROBS_B: ( - "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 ), # Feed-forward up @@ -310,6 +323,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert + "transformer.layer.{bid}.ffn.lin1", # distillbert "transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.linear_3", # refact "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon @@ -322,12 +336,16 @@ class TensorNameMap: "model.layers.layers.{bid}.mlp.up_proj", # plamo "model.layers.{bid}.feed_forward.w3", # internlm2 "encoder.layers.{bid}.mlp.fc11", # nomic-bert + "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe "model.layers.{bid}.mlp.c_fc", # starcoder2 - "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 + "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 (split up/gate, no longer used) + "encoder.layer.{bid}.mlp.gated_layers", # jina-bert-v2 (GEGLU) + "encoder.layer.{bid}.mlp.up_gated_layer", # jina-v2-code (GEGLU) "model.layers.{bid}.residual_mlp.w3", # arctic "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone - "language_model.model.layers.{bid}.feed_forward.up_proj", # llama4 + "model.layers.{bid}.feed_forward.up_proj", # llama4 + "transformer_encoder.{bid}.ffn.w12", # neobert ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -336,13 +354,14 @@ class TensorNameMap: "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) - "language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4 + "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 + "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe ), MODEL_TENSOR.FFN_UP_SHEXP: ( - "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 - "language_model.model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 + "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 ), # AWQ-activation gate @@ -359,26 +378,26 @@ class TensorNameMap: "model.layers.layers.{bid}.mlp.gate_proj", # plamo "model.layers.{bid}.feed_forward.w1", # internlm2 "encoder.layers.{bid}.mlp.fc12", # nomic-bert - "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 + "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 (split up/gate, no longer used) "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic "transformer.h.{bid}.mlp.c_fc_0", # exaone - "language_model.model.layers.{bid}.feed_forward.gate_proj", # llama4 + "model.layers.{bid}.feed_forward.gate_proj", # llama4 ), MODEL_TENSOR.FFN_GATE_EXP: ( - "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx - "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) - "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) - "language_model.model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 + "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx + "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 ), MODEL_TENSOR.FFN_GATE_SHEXP: ( - "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 - "language_model.model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 + "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 ), # Feed-forward down @@ -391,6 +410,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert + "transformer.layer.{bid}.ffn.lin2", # distillbert "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon @@ -407,7 +427,8 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone - "language_model.model.layers.{bid}.feed_forward.down_proj", # llama4 + "model.layers.{bid}.feed_forward.down_proj", # llama4 + "transformer_encoder.{bid}.ffn.w3", # neobert ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -417,13 +438,15 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) - "language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4 + "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 + "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe ), MODEL_TENSOR.FFN_DOWN_SHEXP: ( - "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 - "language_model.model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 + "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 + "model.layers.{bid}.shared_mlp.output_linear", # granitemoe ), MODEL_TENSOR.ATTN_Q_NORM: ( @@ -450,6 +473,7 @@ class TensorNameMap: MODEL_TENSOR.LAYER_OUT_NORM: ( "encoder.layer.{bid}.output.LayerNorm", # bert + "transformer.layer.{bid}.output_layer_norm", # distillbert "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 @@ -677,6 +701,14 @@ class TensorNameMap: "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), + MODEL_TENSOR.ATTN_K_B: ( + "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 + ), + + MODEL_TENSOR.ATTN_V_B: ( + "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + ), + MODEL_TENSOR.ATTN_Q_A_NORM: ( "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), @@ -807,11 +839,14 @@ class TensorNameMap: # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 + "layer_norm", # neobert ), MODEL_TENSOR.CLS: ( "classifier", # jina "classifier.dense", # roberta + "pre_classifier", # distillbert + "dense", # neobert ), MODEL_TENSOR.CLS_OUT: ( @@ -878,6 +913,295 @@ class TensorNameMap: MODEL_TENSOR.POSNET_ATTN_OUT: ( "backbone.posnet.{bid}.proj_out", # wavtokenizer ), + + ############################################################################# + ## Vision encoder + + MODEL_TENSOR.V_MMPROJ: ( + "multi_modal_projector.linear_{bid}", + "visual.merger.mlp.{bid}", # qwen2vl + ), + + MODEL_TENSOR.V_MMPROJ_FC: ( + "model.connector.modality_projection.proj", # SmolVLM + ), + + MODEL_TENSOR.V_MMPROJ_MLP: ( + "model.mm_projector.mlp.mlp.{bid}", + "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 + "mlp1.{bid}", # InternVL + ), + + MODEL_TENSOR.V_MMPROJ_PEG: ( + "model.mm_projector.peg.peg.{bid}", + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS: ( + "vision_tower.vision_model.embeddings.class_embedding", + "vision_model.class_embedding", # llama 4 + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH: ( + "vision_tower.vision_model.embeddings.patch_embedding", + "vpm.embeddings.patch_embedding", + "model.vision_model.embeddings.patch_embedding", # SmolVLM + "vision_tower.patch_conv", # pixtral + "vision_model.patch_embedding.linear", # llama 4 + "visual.patch_embed.proj", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_EMBD_POS: ( + "vision_tower.vision_model.embeddings.position_embedding", + "vpm.embeddings.position_embedding", + "model.vision_model.embeddings.position_embedding", # SmolVLM + "vision_model.positional_embedding_vlm", # llama 4 + ), + + MODEL_TENSOR.V_ENC_ATTN_Q: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + "vpm.encoder.layers.{bid}.self_attn.q_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral + "visual.blocks.{bid}.attn.q", # qwen2vl, generated + ), + + MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL + ), + + MODEL_TENSOR.V_ENC_ATTN_K: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + "vpm.encoder.layers.{bid}.self_attn.k_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral + "visual.blocks.{bid}.attn.k", # qwen2vl, generated + ), + + MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL + ), + + MODEL_TENSOR.V_ENC_ATTN_V: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + "vpm.encoder.layers.{bid}.self_attn.v_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral + "visual.blocks.{bid}.attn.v", # qwen2vl, generated + ), + + MODEL_TENSOR.V_ENC_INPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL + "vpm.encoder.layers.{bid}.layer_norm1", + "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM + "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral + "vision_model.model.layers.{bid}.input_layernorm", # llama4 + "visual.blocks.{bid}.norm1", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_ATTN_O: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL + "vpm.encoder.layers.{bid}.self_attn.out_proj", + "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral + "visual.blocks.{bid}.attn.proj", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL + "vpm.encoder.layers.{bid}.layer_norm2", + "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM + "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 + "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral + "visual.blocks.{bid}.norm2", # qwen2vl + ), + + MODEL_TENSOR.V_ENC_FFN_UP: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + "vpm.encoder.layers.{bid}.mlp.fc1", + "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 + "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc1", # llama4 + "visual.blocks.{bid}.mlp.fc1", # qwen2vl + "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl + ), + + MODEL_TENSOR.V_ENC_FFN_GATE: ( + "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral + "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl + ), + + MODEL_TENSOR.V_ENC_FFN_DOWN: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + "vpm.encoder.layers.{bid}.mlp.fc2", + "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 + "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc2", # llama4 + "visual.blocks.{bid}.mlp.fc2", # qwen2vl + "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl + ), + + MODEL_TENSOR.V_LAYER_SCALE_1: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL + ), + + MODEL_TENSOR.V_LAYER_SCALE_2: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL + ), + + MODEL_TENSOR.V_PRE_NORM: ( + "vision_tower.vision_model.pre_layrnorm", + "vision_tower.ln_pre", # pixtral + "vision_model.layernorm_pre", # llama4 + ), + + MODEL_TENSOR.V_POST_NORM: ( + "vision_tower.vision_model.post_layernorm", + "model.vision_model.post_layernorm", # SmolVLM + "vision_model.layernorm_post", # llama4 + "visual.merger.ln_q", # qwen2vl + ), + + MODEL_TENSOR.V_MM_INP_PROJ: ( + "multi_modal_projector.mm_input_projection", + ), + + MODEL_TENSOR.V_MM_INP_NORM: ( + "multi_modal_projector.norm", + ), + + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( + "multi_modal_projector.mm_soft_emb_norm", + ), + + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ( + "resampler.pos_embed_k", + ), + + MODEL_TENSOR.V_RESMPL_ATTN_Q: ( + "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_K: ( + "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_V: ( + "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj + ), + + MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( + "resampler.attn.out_proj", + ), + + MODEL_TENSOR.V_RESMPL_KV: ( + "resampler.kv_proj", + ), + + MODEL_TENSOR.V_RESMPL_POST_NORM: ( + "resampler.ln_post", + ), + + MODEL_TENSOR.V_RESMPL_KV_NORM: ( + "resampler.ln_kv", + ), + + MODEL_TENSOR.V_RESMPL_Q_NORM: ( + "resampler.ln_q", + ), + + MODEL_TENSOR.V_RESMPL_PROJ: ( + "resampler.proj", + ), + + MODEL_TENSOR.V_RESMPL_QUERY: ( + "resampler.query", + ), + + MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: ( + "v.token_embd.img_break", # for pixtral, this is a generated vector + ), + + MODEL_TENSOR.V_MM_PATCH_MERGER: ( + "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 + ), + + # audio (mtmd) + + MODEL_TENSOR.A_ENC_EMBD_POS: ( + "audio_tower.embed_positions", # ultravox + ), + + MODEL_TENSOR.A_ENC_CONV1D: ( + "audio_tower.conv{bid}", # ultravox + ), + + MODEL_TENSOR.A_PRE_NORM: (), + + MODEL_TENSOR.A_POST_NORM: ( + "audio_tower.layer_norm", # ultravox + "audio_tower.ln_post", # qwen2omni + ), + + MODEL_TENSOR.A_ENC_ATTN_Q: ( + "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_ATTN_K: ( + "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_ATTN_V: ( + "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_INPUT_NORM: ( + "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox + ), + + MODEL_TENSOR.A_ENC_OUTPUT: ( + "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox + ), + + MODEL_TENSOR.A_ENC_OUTPUT_NORM: ( + "audio_tower.layers.{bid}.final_layer_norm", # ultravox + ), + + MODEL_TENSOR.A_ENC_FFN_UP: ( + "audio_tower.layers.{bid}.fc1", # ultravox + ), + + MODEL_TENSOR.A_ENC_FFN_GATE: (), + + MODEL_TENSOR.A_ENC_FFN_DOWN: ( + "audio_tower.layers.{bid}.fc2", # ultravox + ), + + # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors + # this prefix is added in the conversion code in modify_tensors() + + MODEL_TENSOR.A_MMPROJ: ( + "audio.multi_modal_projector.linear_{bid}", # ultravox + ), + + MODEL_TENSOR.A_MMPROJ_FC: ( + "audio.multi_modal_projector.linear", # qwen2audio + "audio_tower.proj", # qwen2omni + ), + + MODEL_TENSOR.A_MM_NORM_PRE: ( + "audio.multi_modal_projector.ln_pre", # ultravox + ), + + MODEL_TENSOR.A_MM_NORM_MID: ( + "audio.multi_modal_projector.ln_mid", # ultravox + ), } # architecture-specific block mappings diff --git a/gguf-py/gguf/utility.py b/gguf-py/gguf/utility.py index e5251aef8..00adcbc93 100644 --- a/gguf-py/gguf/utility.py +++ b/gguf-py/gguf/utility.py @@ -231,7 +231,7 @@ class SafetensorRemote: response.raise_for_status() # Get raw byte data - return response.content[:size] + return response.content[slice(size if size > -1 else None)] @classmethod def check_file_exist(cls, url: str) -> bool: diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index d214e8720..f11351cba 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.16.0" +version = "0.17.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ @@ -23,16 +23,21 @@ numpy = ">=1.17" tqdm = ">=4.27" pyyaml = ">=5.1" sentencepiece = ">=0.1.98,<=0.2.0" +PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true } [tool.poetry.dev-dependencies] pytest = "^5.2" +[tool.poetry.extras] +gui = ["PySide6"] + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -gguf-convert-endian = "gguf.scripts:gguf_convert_endian_entrypoint" -gguf-dump = "gguf.scripts:gguf_dump_entrypoint" -gguf-set-metadata = "gguf.scripts:gguf_set_metadata_entrypoint" -gguf-new-metadata = "gguf.scripts:gguf_new_metadata_entrypoint" +gguf-convert-endian = "gguf.scripts.gguf_convert_endian:main" +gguf-dump = "gguf.scripts.gguf_dump:main" +gguf-set-metadata = "gguf.scripts.gguf_set_metadata:main" +gguf-new-metadata = "gguf.scripts.gguf_new_metadata:main" +gguf-editor-gui = "gguf.scripts.gguf_editor_gui:main" diff --git a/grammars/README.md b/grammars/README.md index 935213f5c..a63198b5a 100644 --- a/grammars/README.md +++ b/grammars/README.md @@ -1,6 +1,6 @@ # GBNF Guide -GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis. GBNF grammars are supported in various ways in `examples/main` and `examples/server`. +GBNF (GGML BNF) is a format for defining [formal grammars](https://en.wikipedia.org/wiki/Formal_grammar) to constrain model outputs in `llama.cpp`. For example, you can use it to force the model to generate valid JSON, or speak only in emojis. GBNF grammars are supported in various ways in `tools/main` and `tools/server`. ## Background @@ -110,21 +110,21 @@ While semantically correct, the syntax `x? x? x?.... x?` (with N repetitions) ma You can use GBNF grammars: -- In [llama-server](../examples/server)'s completion endpoints, passed as the `grammar` body field -- In [llama-cli](../examples/main), passed as the `--grammar` & `--grammar-file` flags -- With [llama-gbnf-validator](../examples/gbnf-validator) tool, to test them against strings. +- In [llama-server](../tools/server)'s completion endpoints, passed as the `grammar` body field +- In [llama-cli](../tools/main), passed as the `--grammar` & `--grammar-file` flags +- With [test-gbnf-validator](../tests/test-gbnf-validator.cpp), to test them against strings. ## JSON Schemas → GBNF `llama.cpp` supports converting a subset of https://json-schema.org/ to GBNF grammars: -- In [llama-server](../examples/server): +- In [llama-server](../tools/server): - For any completion endpoints, passed as the `json_schema` body field - For the `/chat/completions` endpoint, passed inside the `response_format` body field (e.g. `{"type", "json_object", "schema": {"items": {}}}` or `{ type: "json_schema", json_schema: {"schema": ...} }`) -- In [llama-cli](../examples/main), passed as the `--json` / `-j` flag +- In [llama-cli](../tools/main), passed as the `--json` / `-j` flag - To convert to a grammar ahead of time: - in CLI, with [examples/json_schema_to_grammar.py](../examples/json_schema_to_grammar.py) - - in JavaScript with [json-schema-to-grammar.mjs](../examples/server/public_legacy/json-schema-to-grammar.mjs) (this is used by the [server](../examples/server)'s Web UI) + - in JavaScript with [json-schema-to-grammar.mjs](../tools/server/public_legacy/json-schema-to-grammar.mjs) (this is used by the [server](../tools/server)'s Web UI) Take a look at [tests](../tests/test-json-schema-to-grammar.cpp) to see which features are likely supported (you'll also find usage examples in https://github.com/ggml-org/llama.cpp/pull/5978, https://github.com/ggml-org/llama.cpp/pull/6659 & https://github.com/ggml-org/llama.cpp/pull/6555). diff --git a/include/llama.h b/include/llama.h index 5fd99f4c8..635508b10 100644 --- a/include/llama.h +++ b/include/llama.h @@ -4,6 +4,7 @@ #include "ggml.h" #include "ggml-cpu.h" #include "ggml-backend.h" +#include "ggml-opt.h" #include #include @@ -60,7 +61,10 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_kv_cache; + + typedef struct llama_memory_i * llama_memory_t; + + struct llama_kv_cache; // DEPRECATED (use llama_memory instead) typedef int32_t llama_pos; typedef int32_t llama_token; @@ -111,6 +115,8 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, }; enum llama_rope_type { @@ -237,18 +243,21 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); - // Input data for llama_decode + // Input data for llama_encode/llama_decode // A llama_batch object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens // // - token : the token ids of the input (used when embd is NULL) // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence - // (if set to NULL, the token position will be tracked automatically by llama_decode) + // (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode) // - seq_id : the sequence to which the respective token belongs // (if set to NULL, the sequence ID will be assumed to be 0) // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output - // (if set to NULL, only the logits for last token will be returned) + // (if set to NULL: + // - if embeddings: all tokens are output + // - if not: only the last token is output + // ) // typedef struct llama_batch { int32_t n_tokens; @@ -258,7 +267,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" + int8_t * logits; // TODO: rename this to "output" } llama_batch; enum llama_model_kv_override_type { @@ -342,7 +351,7 @@ extern "C" { float yarn_beta_fast; // YaRN low correction dim float yarn_beta_slow; // YaRN high correction dim uint32_t yarn_orig_ctx; // YaRN original context size - float defrag_thold; // defragment the KV cache if holes/size > thold, < 0 disabled (default) + float defrag_thold; // defragment the KV cache if holes/size > thold, <= 0 disabled (default) ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; @@ -350,34 +359,37 @@ extern "C" { enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] - // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. - // TODO: move at the end of the struct - bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) - bool embeddings; // if true, extract embeddings (together with logits) - bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU - bool flash_attn; // whether to use flash attention [EXPERIMENTAL] - bool no_perf; // whether to measure performance timings - // Abort callback // if it returns true, execution of llama_decode() will be aborted // currently works only with CPU execution ggml_abort_callback abort_callback; void * abort_callback_data; + + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. + bool embeddings; // if true, extract embeddings (together with logits) + bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU + bool flash_attn; // use flash attention [EXPERIMENTAL] + bool no_perf; // measure performance timings + bool op_offload; // offload host tensor operations to device + bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) + // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases + // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 }; // model quantization parameters typedef struct llama_model_quantize_params { - int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - enum ggml_type output_tensor_type; // output tensor type - enum ggml_type token_embedding_type; // token embeddings tensor type - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - bool pure; // quantize all tensors to the default type - bool keep_split; // quantize to the same number of shards - void * imatrix; // pointer to importance matrix data - void * kv_overrides; // pointer to vector containing overrides + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + void * imatrix; // pointer to importance matrix data + void * kv_overrides; // pointer to vector containing overrides + void * tensor_types; // pointer to vector containing tensor types } llama_model_quantize_params; typedef struct llama_logit_bias { @@ -443,6 +455,10 @@ extern "C" { size_t n_paths, struct llama_model_params params); + LLAMA_API void llama_model_save_to_file( + const struct llama_model * model, + const char * path_model); + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), "use llama_model_free instead"); @@ -463,6 +479,7 @@ extern "C" { LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); + LLAMA_API size_t llama_max_parallel_sequences(void); LLAMA_API bool llama_supports_mmap (void); LLAMA_API bool llama_supports_mlock (void); @@ -482,9 +499,11 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); - LLAMA_API struct llama_kv_cache * llama_get_kv_self ( struct llama_context * ctx); + LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type + DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead"); + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); @@ -493,10 +512,18 @@ extern "C" { LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); + // Returns the number of classifier outputs (only valid for classifier models) + // Undefined behavior for non-classifier models + LLAMA_API uint32_t llama_model_n_cls_out(const struct llama_model * model); + + // Returns label of classifier output by index ( 1` + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + + // Returns the smallest position present in the memory for the specified sequence + // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id); + + // Returns the largest position present in the memory for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the memory + // Return -1 if the sequence is empty + LLAMA_API llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id); + + // Check if the memory supports shifting + LLAMA_API bool llama_memory_can_shift(llama_memory_t mem); + + // + // KV cache for self-attention (TODO: deprecate in favor of llama_memory) + // + + // Returns the number of tokens in the KV cache (slow, use only for debug) + // If a KV cell has multiple sequences assigned to it, it will be counted multiple times + DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx), + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); + + // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) + DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx), + "Use llama_kv_self_seq_pos_max() and llama_kv_self_seq_pos_min() instead (https://github.com/ggml-org/llama.cpp/issues/13793)"); + + // Clear the KV cache - both cell info is erased and KV data is zeroed + DEPRECATED(LLAMA_API void llama_kv_self_clear( + struct llama_context * ctx), + "Use llama_memory_clear() instead"); + + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails + // seq_id < 0 : match any sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + DEPRECATED(LLAMA_API bool llama_kv_self_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_rm() instead"); // Copy all tokens that belong to the specified sequence to another sequence // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_cp( + DEPRECATED(LLAMA_API void llama_kv_self_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, - llama_pos p1); + llama_pos p1), + "Use llama_memory_seq_cp() instead"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_self_seq_keep( + DEPRECATED(LLAMA_API void llama_kv_self_seq_keep( struct llama_context * ctx, - llama_seq_id seq_id); + llama_seq_id seq_id), + "Use llama_memory_seq_keep() instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_add( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta); - - // Integer division of the positions by factor of `d > 1` - // If the KV cache is RoPEd, the KV data is updated accordingly: - // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_self_seq_div( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d); - - // Returns the largest position present in the KV cache for the specified sequence - LLAMA_API llama_pos llama_kv_self_seq_pos_max( - struct llama_context * ctx, - llama_seq_id seq_id); - - // Defragment the KV cache - // This will be applied: - // - lazily on next llama_decode() - // - explicitly with llama_kv_self_update() - LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx); - - // Check if the context supports KV cache shifting - LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx); - - // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) - LLAMA_API void llama_kv_self_update(struct llama_context * ctx); - - DEPRECATED(LLAMA_API void llama_kv_cache_clear( - struct llama_context * ctx), - "use llama_kv_self_clear instead"); - - DEPRECATED(LLAMA_API bool llama_kv_cache_seq_rm( - struct llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1), - "use llama_kv_self_seq_rm instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_cp( - struct llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1), - "use llama_kv_self_seq_cp instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_keep( - struct llama_context * ctx, - llama_seq_id seq_id), - "use llama_kv_self_seq_keep instead"); - - DEPRECATED(LLAMA_API void llama_kv_cache_seq_add( + DEPRECATED(LLAMA_API void llama_kv_self_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta), - "use llama_kv_self_seq_add instead"); + "Use llama_memory_seq_add() instead"); - DEPRECATED(LLAMA_API void llama_kv_cache_seq_div( + // Integer division of the positions by factor of `d > 1` + // If the KV cache is RoPEd, the KV data is updated accordingly: + // - lazily on next llama_decode() + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + DEPRECATED(void llama_kv_self_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d), - "use llama_kv_self_seq_div instead"); + "Use llama_memory_seq_div() instead"); - DEPRECATED(LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + // Returns the smallest position present in the KV cache for the specified sequence + // This is typically non-zero only for SWA caches + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache + // Return -1 if the sequence is empty + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_min( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_kv_self_seq_pos_max instead"); + "Use llama_memory_seq_pos_min() instead"); - DEPRECATED(LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx), - "use llama_kv_self_defrag instead"); + // Returns the largest position present in the KV cache for the specified sequence + // Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache + // Return -1 if the sequence is empty + DEPRECATED(LLAMA_API llama_pos llama_kv_self_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "Use llama_memory_seq_pos_max() instead"); - DEPRECATED(LLAMA_API bool llama_kv_cache_can_shift(const struct llama_context * ctx), - "use llama_kv_self_can_shift instead"); + // Defragment the KV cache + // This will be applied: + // - lazily on next llama_decode() + DEPRECATED(LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx), + "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'"); - DEPRECATED(LLAMA_API void llama_kv_cache_update(struct llama_context * ctx), - "use llama_kv_self_update instead"); + // Check if the context supports KV cache shifting + DEPRECATED(LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx), + "use llama_memory_can_shift() instead"); + // Apply the KV cache updates (such as K-shifts, defragmentation, etc.) + DEPRECATED(LLAMA_API void llama_kv_self_update(struct llama_context * ctx), + "simply remove this call, updates are applied lazily on the next llama_decode()"); // // State / sessions // // Returns the *actual* size in bytes of the state - // (logits, embedding and kv_cache) + // (logits, embedding and memory) // Only use when saving the state, not when restoring it, otherwise the size may be too small. LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), @@ -856,12 +863,12 @@ extern "C" { size_t n_token_count), "use llama_state_save_file instead"); - // Get the exact size needed to copy the KV cache of a single sequence + // Get the exact size needed to copy the state of a single sequence LLAMA_API size_t llama_state_seq_get_size( struct llama_context * ctx, llama_seq_id seq_id); - // Copy the KV cache of a single sequence into the specified buffer + // Copy the state of a single sequence into the specified buffer LLAMA_API size_t llama_state_seq_get_data( struct llama_context * ctx, uint8_t * dst, @@ -922,18 +929,26 @@ extern "C" { // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); - // Processes a batch of tokens with the ecoder part of the encoder-decoder model. - // Stores the encoder output internally for later use by the decoder cross-attention layers. + // Process a batch of tokens. + // In contrast to llama_decode() - this call does not use KV cache. + // For encode-decoder contexts, processes the batch using the encoder. + // Can store the encoder output internally for later use by the decoder's cross-attention layers. // 0 - success - // < 0 - error. the KV cache state is restored to the state before this call + // < 0 - error. the memory state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); + // Process a batch of tokens. + // Requires the context to have a memory. + // For encode-decoder contexts, processes the batch using the decoder. // Positive return values does not mean a fatal error, but rather a warning. - // 0 - success - // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) - // < 0 - error. the KV cache state is restored to the state before this call + // Upon non-zero return values, the memory state is restored to the state before this call + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + // 2 - aborted + // -1 - invalid input batch + // < -1 - error LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch); @@ -949,8 +964,8 @@ extern "C" { // Get the number of threads used for prompt and batch processing (multiple token). LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx); - // Set whether the model is in embeddings mode or not - // If true, embeddings will be returned but logits will not + // Set whether the context outputs embeddings or not + // TODO: rename to avoid confusion with llama_get_embeddings() LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings); // Set whether to use causal attention or not @@ -999,7 +1014,7 @@ extern "C" { // Get the embeddings for a sequence id // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE - // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence + // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[n_cls_out] with the rank(s) of the sequence // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); @@ -1230,6 +1245,7 @@ extern "C" { "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + /// Setting k <= 0 makes this a noop LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 @@ -1425,6 +1441,37 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); + // + // training + // + + // function that returns whether or not a given tensor contains trainable parameters + typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata); + + // always returns true + LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata); + + struct llama_opt_params { + uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0 + + llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters + void * param_filter_ud; // userdata for determining which tensors contain trainable parameters + + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + }; + + LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params); + + LLAMA_API void llama_opt_epoch( + struct llama_context * lctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + #ifdef __cplusplus } #endif diff --git a/models/ggml-vocab-bert-bge.gguf.inp b/models/ggml-vocab-bert-bge.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-bert-bge.gguf.inp +++ b/models/ggml-vocab-bert-bge.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-bert-bge.gguf.out b/models/ggml-vocab-bert-bge.gguf.out index a62566ce7..b1c49672f 100644 --- a/models/ggml-vocab-bert-bge.gguf.out +++ b/models/ggml-vocab-bert-bge.gguf.out @@ -1,5 +1,5 @@ 29464 2094 1018 1092 2706 - 11865 17875 + 9706 7959 2140 diff --git a/models/ggml-vocab-chameleon.gguf.inp b/models/ggml-vocab-chameleon.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-chameleon.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-chameleon.gguf.out b/models/ggml-vocab-chameleon.gguf.out deleted file mode 100644 index 7c5413fee..000000000 --- a/models/ggml-vocab-chameleon.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 17245 16604 16403 16604 33583 18355 - 16421 51153 - - 16604 - 16650 - 16650 16604 - 16581 - 16582 - 16582 16582 - 16582 16582 16582 - 16581 16582 - 31596 17394 - 34926 17394 - 31596 18671 - 34926 18671 - 34926 18671 16384 - 31596 16395 17394 16384 - 34926 16395 17394 16384 - 16811 16704 20410 16483 16631 16397 52854 - 16470 16399 16403 16407 16604 16406 35764 38185 51595 22592 26639 - 29479 23955 17012 20103 25527 27670 17408 19005 21473 24774 - 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 21954 16607 21954 16633 21954 16611 29409 16607 21954 16615 - 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 16604 16391 24664 17153 57169 16721 16872 17073 17304 28729 16392 - 31596 - 34926 - 16650 31596 - 16650 34926 - 16696 31596 - 16696 31596 16582 16696 31596 - 16604 16391 - 16582 16604 16412 - 16390 22623 - 31596 16395 16712 16390 16828 16384 17674 16769 16732 23686 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 - 16384 16384 16384 16384 16384 16384 - 16402 - 16402 16402 - 16402 16402 16402 - 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 16402 16402 - 16402 16402 16402 16402 16402 16402 16402 16402 16402 - 16418 19038 16639 16448 24315 33727 16467 - 18765 17981 - 16582 16604 16582 16582 16604 16582 16582 16582 16604 16581 16604 16581 16581 16604 16581 16582 16650 16582 16650 16604 16582 16696 16582 16696 16604 16582 52351 16604 16391 25825 16392 23686 16498 39161 18885 16618 16488 30853 16604 16391 54124 17153 25134 16656 18476 26169 16895 16392 62193 16611 20410 16483 16631 18885 16483 16631 16604 16402 16604 16402 16402 16604 16402 16402 16402 16604 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16604 16402 16402 16402 16402 16402 16402 16402 16402 16604 16402 16397 16402 16604 16402 16397 16397 16402 16604 16402 16397 16397 16397 16402 16604 54254 42231 48084 29409 16617 61889 29409 16608 21954 16628 21954 16499 58445 29409 16607 58445 21954 16479 42231 21954 16611 27683 16607 16604 16414 24427 16623 41809 16495 28999 36469 45292 30197 16400 16402 16400 16403 16400 16404 16400 43969 65211 16636 16604 16396 16396 16396 16396 16396 16396 16412 16412 16412 16412 16412 16412 16412 27268 23955 17012 20103 25527 27670 17408 19005 21473 24774 16604 16390 16390 16390 16390 16390 16390 16447 16447 16447 16447 16447 16447 16447 16385 16385 16385 16385 16397 16397 16397 16397 16397 16397 16384 16384 16384 16384 16384 16384 16414 16414 16414 16414 16414 16414 16687 16390 16690 16992 16604 16390 61797 16733 16390 16466 16986 16395 16604 16390 17879 16732 17811 16414 16604 16390 16428 16804 17811 16687 16390 16683 17190 16728 16395 16604 16390 16419 16732 16945 16991 25251 16414 17119 16390 38127 16641 16390 16459 16427 diff --git a/models/ggml-vocab-command-r.gguf.inp b/models/ggml-vocab-command-r.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-command-r.gguf.inp +++ b/models/ggml-vocab-command-r.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-command-r.gguf.out b/models/ggml-vocab-command-r.gguf.out index 3f6b41888..0e3af72eb 100644 --- a/models/ggml-vocab-command-r.gguf.out +++ b/models/ggml-vocab-command-r.gguf.out @@ -1,5 +1,5 @@ 2536 228 27 228 22957 6983 - 45 193433 + 90711 87 20910 228 1667 diff --git a/models/ggml-vocab-deepseek-coder.gguf.inp b/models/ggml-vocab-deepseek-coder.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-deepseek-coder.gguf.inp +++ b/models/ggml-vocab-deepseek-coder.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-coder.gguf.out b/models/ggml-vocab-deepseek-coder.gguf.out index 52c4111a1..ef6bc5b8a 100644 --- a/models/ggml-vocab-deepseek-coder.gguf.out +++ b/models/ggml-vocab-deepseek-coder.gguf.out @@ -1,5 +1,5 @@ 1050 207 19 207 19192 4217 - 37 32009 71 6247 + 125 213 26862 282 207 243 diff --git a/models/ggml-vocab-deepseek-llm.gguf.inp b/models/ggml-vocab-deepseek-llm.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-deepseek-llm.gguf.inp +++ b/models/ggml-vocab-deepseek-llm.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-llm.gguf.out b/models/ggml-vocab-deepseek-llm.gguf.out index 0191b7a11..f9d49c9af 100644 --- a/models/ggml-vocab-deepseek-llm.gguf.out +++ b/models/ggml-vocab-deepseek-llm.gguf.out @@ -1,5 +1,5 @@ 1052 207 19 207 19109 4223 - 37 100014 71 6245 + 82077 26723 282 207 243 diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.inp b/models/ggml-vocab-deepseek-r1-qwen.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-deepseek-r1-qwen.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-deepseek-r1-qwen.gguf.out b/models/ggml-vocab-deepseek-r1-qwen.gguf.out deleted file mode 100644 index 18b4b45cd..000000000 --- a/models/ggml-vocab-deepseek-r1-qwen.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 1122 220 19 220 26062 3951 - 37 50753 261 - - 220 - 256 - 262 - 197 - 198 - 271 - 1406 - 1572 - 9707 1879 - 21927 1879 - 9707 4337 - 21927 4337 - 21927 4337 0 - 9707 11 1879 0 - 21927 11 1879 0 - 419 374 11162 99 247 13 10821 - 86 15 19 23 220 22 83 1963 41808 11472 2940 16739 - 78762 14144 1456 13073 63471 33594 3038 133178 79012 - 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 147805 148301 147270 44258 223 146848 - 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 320 3243 42365 429 702 1181 1828 3950 8 - 9707 - 21927 - 220 21927 - 256 21927 - 262 21927 - 262 21927 198 262 21927 - 320 - 198 284 - 6 11385 - 9707 11 379 64848 0 2585 525 498 26525 223 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 - 17085 2928 - 18 - 18 18 - 18 18 18 - 18 18 18 18 - 18 18 18 18 18 - 18 18 18 18 18 18 - 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 - 18 18 18 18 18 18 18 18 18 - 34 90063 128324 - 2560 2347 - 198 4710 14731 65497 7847 1572 2303 78672 10947 145836 320 8252 8 26525 114 378 235 149921 30543 320 35673 99066 97534 8 25521 227 11162 99 247 149955 220 18 220 18 18 220 18 18 18 220 18 18 18 18 220 18 18 18 18 18 220 18 18 18 18 18 18 220 18 18 18 18 18 18 18 220 18 18 18 18 18 18 18 18 220 18 13 18 220 18 496 18 220 18 1112 18 220 146394 97529 241 44258 233 146568 44258 224 147603 20879 115 146280 44258 223 146280 147272 97529 227 144534 937 104100 18493 22377 99257 16 18 16 19 16 20 16 35727 21216 55460 53237 18658 14144 1456 13073 63471 33594 3038 133178 79012 3355 4605 4605 13874 13874 73594 3014 3014 28149 17085 2928 26610 7646 358 3003 1012 364 83 813 566 594 1052 11 364 787 498 2704 30 364 44 537 2704 358 3278 1281 432 11 364 35 498 1075 1045 15243 30 1205 6 42612 264 63866 43 diff --git a/models/ggml-vocab-falcon.gguf.inp b/models/ggml-vocab-falcon.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-falcon.gguf.inp +++ b/models/ggml-vocab-falcon.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-falcon.gguf.out b/models/ggml-vocab-falcon.gguf.out index 64a48d97f..6319de60e 100644 --- a/models/ggml-vocab-falcon.gguf.out +++ b/models/ggml-vocab-falcon.gguf.out @@ -1,5 +1,5 @@ 878 204 31 3068 133 2137 - 28611 132 30042 + 34502 18614 286 204 258 diff --git a/models/ggml-vocab-gpt-2.gguf.inp b/models/ggml-vocab-gpt-2.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-gpt-2.gguf.inp +++ b/models/ggml-vocab-gpt-2.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-gpt-2.gguf.out b/models/ggml-vocab-gpt-2.gguf.out index 17a13bdfc..6464ded3d 100644 --- a/models/ggml-vocab-gpt-2.gguf.out +++ b/models/ggml-vocab-gpt-2.gguf.out @@ -1,5 +1,5 @@ 798 604 25208 1933 - 37 9116 71 11751 + 127 226 79 69 417 220 220 220 diff --git a/models/ggml-vocab-gpt-4o.gguf.inp b/models/ggml-vocab-gpt-4o.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-gpt-4o.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-gpt-4o.gguf.out b/models/ggml-vocab-gpt-4o.gguf.out deleted file mode 100644 index 478df726f..000000000 --- a/models/ggml-vocab-gpt-4o.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 1165 220 19 220 27124 5503 - 37 19194 259 - - 220 - 256 - 271 - 197 - 198 - 279 - 2499 - 2775 - 13225 2375 - 32949 2375 - 13225 5922 - 32949 5922 - 32949 5922 0 - 13225 11 2375 0 - 32949 11 2375 0 - 495 382 9552 99 247 13 17159 - 86 45404 220 22 10191 2852 22924 4750 6916 - 3907 53641 1235 185386 8118 - 11400 107516 15867 20804 22851 134178 77431 32010 104312 37984 16329 27751 89335 - 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 350 7393 74471 484 853 1617 2316 6602 8 - 13225 - 32949 - 220 32949 - 256 32949 - 271 32949 - 271 32949 198 271 32949 - 350 - 198 314 - 6 6837 - 13225 11 342 70653 0 3253 553 481 22861 223 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208 - 147475 - 18 - 2546 - 15517 - 15517 18 - 15517 2546 - 15517 15517 - 15517 15517 18 - 15517 15517 2546 - 15517 15517 15517 - 34 60213 53904 - 2960 3098 - 126470 25980 160432 16609 2775 4066 172261 19432 112927 222 350 14559 8 22861 114 2524 64364 104 15148 350 76466 166700 121942 780 8 91349 9552 99 247 4103 99 247 220 18 220 2546 220 15517 220 15517 18 220 15517 2546 220 15517 15517 220 15517 15517 18 220 15517 15517 2546 220 18 13 18 220 18 485 18 220 18 1008 18 44735 107516 15867 20804 22851 134178 77431 32010 104312 156437 1423 7522 18165 2178 34058 22369 16412 32999 16 867 8208 105024 106657 1967 53641 1235 185386 8118 22434 39336 26178 26178 168394 194663 27271 147475 25883 6961 9790 1339 461 83 1280 19016 1354 11 461 1099 481 3239 30 461 44 625 3239 17291 1520 480 11 461 35 481 1299 1236 17966 30 1416 6 27493 261 54602 43 diff --git a/models/ggml-vocab-llama-bpe.gguf.inp b/models/ggml-vocab-llama-bpe.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-llama-bpe.gguf.inp +++ b/models/ggml-vocab-llama-bpe.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-bpe.gguf.out b/models/ggml-vocab-llama-bpe.gguf.out index 4b35cf93f..a77376625 100644 --- a/models/ggml-vocab-llama-bpe.gguf.out +++ b/models/ggml-vocab-llama-bpe.gguf.out @@ -1,5 +1,5 @@ 1142 220 19 220 27154 4038 - 37 51853 261 + 88075 16276 301 220 256 diff --git a/models/ggml-vocab-llama-spm.gguf.inp b/models/ggml-vocab-llama-spm.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-llama-spm.gguf.inp +++ b/models/ggml-vocab-llama-spm.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-llama-spm.gguf.out b/models/ggml-vocab-llama-spm.gguf.out index 93aacf8ba..2a71a6ef8 100644 --- a/models/ggml-vocab-llama-spm.gguf.out +++ b/models/ggml-vocab-llama-spm.gguf.out @@ -1,5 +1,5 @@ 474 287 29871 29946 29871 30226 7378 - 383 4000 261 + 11585 7810 295 259 1678 diff --git a/models/ggml-vocab-llama4.gguf.inp b/models/ggml-vocab-llama4.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-llama4.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-llama4.gguf.out b/models/ggml-vocab-llama4.gguf.out deleted file mode 100644 index 7ca46ce59..000000000 --- a/models/ggml-vocab-llama4.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 1190 220 32 220 18215 7112 - 50 16800 258 - - 220 - 256 - 277 - 197 - 198 - 368 - 2946 - 3271 - 19873 3817 - 39715 3817 - 19873 7353 - 39715 7353 - 39715 7353 13 - 19873 24 3817 13 - 39715 24 3817 13 - 544 373 9522 112 247 26 36315 - 99 39923 220 35 9607 21498 21470 3679 9433 - 1595 7653 633 79829 34051 1636 - 8755 102595 115960 21125 148305 96819 102816 39048 14105 22528 160234 - 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 330 7384 88230 511 947 1492 3742 7233 21 - 19873 - 39715 - 220 39715 - 256 39715 - 277 39715 - 277 39715 198 277 39715 - 330 - 198 319 - 19 7359 - 19873 24 386 87799 13 2403 583 650 51358 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 - 17931 4959 - 31 - 1922 - 12325 - 12325 31 - 12325 1922 - 12325 12325 - 12325 12325 31 - 12325 12325 1922 - 12325 12325 12325 - 47 19811 12077 - 3260 3579 - 198 7283 51499 191231 20192 3271 3322 9287 2143 17860 114590 222 330 14879 21 51358 127 12817 93293 117 24204 330 68239 881 120327 170428 21 89101 9522 112 247 172394 247 220 31 220 1922 220 12325 220 12325 31 220 12325 1922 220 12325 12325 220 12325 12325 31 220 12325 12325 1922 220 31 26 31 220 31 396 31 220 31 1043 31 117131 102595 115960 21125 148305 96819 102816 80883 223 1663 155736 1522 42056 7544 13336 28785 29 4412 20645 79745 150278 117079 633 79829 34051 1636 25611 41990 109428 1488 91054 24072 17931 4959 29795 9296 16517 1806 481 96 1386 36633 1609 24 481 1109 650 5074 43 481 57 702 5074 27088 2170 536 24 481 48 650 1933 1696 30262 43 1665 19 32818 262 27236 56 diff --git a/models/ggml-vocab-mpt.gguf.inp b/models/ggml-vocab-mpt.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-mpt.gguf.inp +++ b/models/ggml-vocab-mpt.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-mpt.gguf.out b/models/ggml-vocab-mpt.gguf.out index 372c751bf..ca62669ad 100644 --- a/models/ggml-vocab-mpt.gguf.out +++ b/models/ggml-vocab-mpt.gguf.out @@ -1,5 +1,5 @@ 728 577 24142 2607 - 39 26288 6554 + 37515 18569 293 209 50276 diff --git a/models/ggml-vocab-nomic-bert-moe.gguf b/models/ggml-vocab-nomic-bert-moe.gguf new file mode 100644 index 000000000..b6f4d9441 Binary files /dev/null and b/models/ggml-vocab-nomic-bert-moe.gguf differ diff --git a/models/ggml-vocab-phi-3.gguf.inp b/models/ggml-vocab-phi-3.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-phi-3.gguf.inp +++ b/models/ggml-vocab-phi-3.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-phi-3.gguf.out b/models/ggml-vocab-phi-3.gguf.out index 93aacf8ba..2a71a6ef8 100644 --- a/models/ggml-vocab-phi-3.gguf.out +++ b/models/ggml-vocab-phi-3.gguf.out @@ -1,5 +1,5 @@ 474 287 29871 29946 29871 30226 7378 - 383 4000 261 + 11585 7810 295 259 1678 diff --git a/models/ggml-vocab-qwen2.gguf.inp b/models/ggml-vocab-qwen2.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-qwen2.gguf.inp +++ b/models/ggml-vocab-qwen2.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-qwen2.gguf.out b/models/ggml-vocab-qwen2.gguf.out index 18b4b45cd..595d59a44 100644 --- a/models/ggml-vocab-qwen2.gguf.out +++ b/models/ggml-vocab-qwen2.gguf.out @@ -1,5 +1,5 @@ 1122 220 19 220 26062 3951 - 37 50753 261 + 86975 15897 301 220 256 diff --git a/models/ggml-vocab-refact.gguf.inp b/models/ggml-vocab-refact.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-refact.gguf.inp +++ b/models/ggml-vocab-refact.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-refact.gguf.out b/models/ggml-vocab-refact.gguf.out index 63d8305c3..f13dda52c 100644 --- a/models/ggml-vocab-refact.gguf.out +++ b/models/ggml-vocab-refact.gguf.out @@ -1,5 +1,5 @@ 4833 225 38 225 143 140 17723 - 56 2006 3935 265 + 144 231 7132 342 225 261 diff --git a/models/ggml-vocab-roberta-bpe.gguf.inp b/models/ggml-vocab-roberta-bpe.gguf.inp deleted file mode 100644 index 9baf7d77a..000000000 --- a/models/ggml-vocab-roberta-bpe.gguf.inp +++ /dev/null @@ -1,112 +0,0 @@ -ied 4 ½ months -__ggml_vocab_test__ -Führer -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - -__ggml_vocab_test__ - - -__ggml_vocab_test__ - - - -__ggml_vocab_test__ - - - - -__ggml_vocab_test__ - - -__ggml_vocab_test__ -Hello world -__ggml_vocab_test__ - Hello world -__ggml_vocab_test__ -Hello World -__ggml_vocab_test__ - Hello World -__ggml_vocab_test__ - Hello World! -__ggml_vocab_test__ -Hello, world! -__ggml_vocab_test__ - Hello, world! -__ggml_vocab_test__ - this is 🦙.cpp -__ggml_vocab_test__ -w048 7tuijk dsdfhu -__ggml_vocab_test__ -нещо на Български -__ggml_vocab_test__ -កាន់តែពិសេសអាចខលចេញ -__ggml_vocab_test__ -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token) -__ggml_vocab_test__ -Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello -__ggml_vocab_test__ - Hello - Hello -__ggml_vocab_test__ - ( -__ggml_vocab_test__ - - = -__ggml_vocab_test__ -' era -__ggml_vocab_test__ -Hello, y'all! How are you 😁 ?我想在apple工作1314151天~ -__ggml_vocab_test__ -!!!!!! -__ggml_vocab_test__ -3 -__ggml_vocab_test__ -33 -__ggml_vocab_test__ -333 -__ggml_vocab_test__ -3333 -__ggml_vocab_test__ -33333 -__ggml_vocab_test__ -333333 -__ggml_vocab_test__ -3333333 -__ggml_vocab_test__ -33333333 -__ggml_vocab_test__ -333333333 -__ggml_vocab_test__ -Cửa Việt -__ggml_vocab_test__ - discards -__ggml_vocab_test__ - - - - - - - - - - - -🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL -__ggml_vocab_test__ diff --git a/models/ggml-vocab-roberta-bpe.gguf.out b/models/ggml-vocab-roberta-bpe.gguf.out deleted file mode 100644 index f181ac3dc..000000000 --- a/models/ggml-vocab-roberta-bpe.gguf.out +++ /dev/null @@ -1,46 +0,0 @@ - 2550 204 18430 377 - 597 2768 298 8564 - - 1437 - 1437 1437 - 1437 1437 1437 - 50117 - 50118 - 50140 - 50140 50118 - 50117 50118 - 31414 232 - 20920 232 - 31414 623 - 20920 623 - 20920 623 328 - 31414 6 232 328 - 20920 6 232 328 - 42 16 8103 18164 27 4 49317 - 605 40976 262 10109 18474 385 29 36807 6455 - 36765 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328 - 1376 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 1376 17772 10172 1376 17772 3726 1376 17772 5782 1376 4333 10172 1376 17772 23171 - 6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 36 8338 21554 14 34 63 308 19233 43 - 31414 - 20920 - 1437 20920 - 1437 1437 20920 - 1437 1437 1437 20920 - 1437 1437 1437 20920 50118 1437 1437 1437 20920 - 36 - 50118 5457 - 108 3567 - 31414 6 1423 108 1250 328 1336 32 47 17841 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772 - 32376 12846 - 246 - 3103 - 25631 - 46152 - 3103 25631 - 46152 3103 - 46152 25631 - 46152 46152 - 46152 3103 25631 - 347 1376 2023 12410 102 16376 1376 2023 6382 90 - 9553 5954 - 50118 1437 50140 1437 50140 50118 1437 50117 1437 50117 50117 1437 50117 50118 1437 1437 50118 1437 1437 1437 50118 1437 1437 1437 1437 50118 1437 1437 1437 1437 1437 50118 6569 15113 7471 36 21113 43 17841 19002 17 8384 6569 14285 4958 12605 36 34654 2841 4203 354 10146 26511 1070 43 36174 5782 8103 18164 27 6569 18164 27 155 2357 30242 155 25631 30242 3103 30242 25631 30242 46152 30242 3103 25631 155 4 246 155 7586 246 155 734 246 25974 17772 7471 1376 17772 19002 1376 17772 9085 1376 4333 13859 1376 17772 9357 1376 4333 9264 1376 17772 25448 1376 17772 18400 1376 17772 4333 1376 4333 10172 1376 17772 4333 1376 17772 7258 1376 17772 19002 1376 17772 5782 18636 10172 17487 47876 3602 48617 15264 46537 11423 27326 48494 8210 49233 1558 1570 27761 49429 43251 10809 17772 36738 48332 47463 18697 10809 25482 22063 23171 34251 18697 10809 26161 18697 3602 22063 27969 40966 25417 15264 26161 24269 36709 41171 35328 128 49690 108 49972 49519 12905 48149 48149 43796 32376 12846 27282 28749 38 348 57 128 41042 37 18 89 6 128 4629 47 686 116 128 448 45 686 38 581 146 24 6 128 495 47 101 103 6845 116 166 108 30660 10 108 462 574 diff --git a/models/ggml-vocab-starcoder.gguf.inp b/models/ggml-vocab-starcoder.gguf.inp index 9baf7d77a..86b934e40 100644 --- a/models/ggml-vocab-starcoder.gguf.inp +++ b/models/ggml-vocab-starcoder.gguf.inp @@ -1,6 +1,6 @@ ied 4 ½ months __ggml_vocab_test__ -Führer +Äpfel __ggml_vocab_test__ __ggml_vocab_test__ diff --git a/models/ggml-vocab-starcoder.gguf.out b/models/ggml-vocab-starcoder.gguf.out index 87e2465d3..4698e2c3c 100644 --- a/models/ggml-vocab-starcoder.gguf.out +++ b/models/ggml-vocab-starcoder.gguf.out @@ -1,5 +1,5 @@ 4850 244 57 244 162 159 17722 - 75 2022 3943 284 + 163 250 7146 361 244 280 diff --git a/models/templates/Qwen-QwQ-32B.jinja b/models/templates/Qwen-QwQ-32B.jinja new file mode 100644 index 000000000..d475f7068 --- /dev/null +++ b/models/templates/Qwen-QwQ-32B.jinja @@ -0,0 +1,62 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0]['role'] == 'system' %} + {{- messages[0]['content'] }} + {%- else %} + {{- '' }} + {%- endif %} + {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0]['role'] == 'system' %} + {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" and not message.tool_calls %} + {%- set content = message.content %} + {%- if not loop.last %} + {%- set content = message.content.split('
')[-1].lstrip('\n') %} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- if not loop.last %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- endif %} + {{- '<|im_start|>' + message.role }} + {%- if message.content %} + {{- '\n' + content }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {{- tool_call.arguments | tojson }} + {{- '}\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n\n' }} +{%- endif %} diff --git a/models/templates/Qwen-Qwen3-0.6B.jinja b/models/templates/Qwen-Qwen3-0.6B.jinja new file mode 100644 index 000000000..699ff8df4 --- /dev/null +++ b/models/templates/Qwen-Qwen3-0.6B.jinja @@ -0,0 +1,85 @@ +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in message.content %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/models/templates/README.md b/models/templates/README.md index e4fd104fc..35b6386dd 100644 --- a/models/templates/README.md +++ b/models/templates/README.md @@ -19,4 +19,6 @@ These templates can be updated with the following commands: ./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja ./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja ./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja +./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja +./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja ``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index ed62264ba..3d71b055a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,5 +40,6 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] llama-convert-hf-to-gguf = "convert_hf_to_gguf:main" +llama-convert-lora-to-gguf = "convert_lora_to_gguf:main" llama-convert-llama-ggml-to-gguf = "convert_llama_ggml_to_gguf:main" llama-ggml-vk-generate-shaders = "ggml_vk_generate_shaders:main" diff --git a/pyrightconfig.json b/pyrightconfig.json index 9acbbeb78..5320fe586 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -15,7 +15,7 @@ }, { // uses match expressions in steps.py - "root": "examples/server/tests", + "root": "tools/server/tests", "pythonVersion": "3.10", }, ], diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt index 439db8886..9fa7d4d0a 100644 --- a/requirements/requirements-all.txt +++ b/requirements/requirements-all.txt @@ -1,6 +1,6 @@ --r ../examples/llava/requirements.txt --r ../examples/server/bench/requirements.txt --r ../examples/server/tests/requirements.txt +-r ../tools/mtmd/requirements.txt +-r ../tools/server/bench/requirements.txt +-r ../tools/server/tests/requirements.txt -r ./requirements-compare-llama-bench.txt -r ./requirements-pydantic.txt @@ -11,3 +11,5 @@ -r ./requirements-convert_legacy_llama.txt -r ./requirements-convert_llama_ggml_to_gguf.txt -r ./requirements-tool_bench.txt + +-r ./requirements-gguf_editor_gui.txt diff --git a/requirements/requirements-compare-llama-bench.txt b/requirements/requirements-compare-llama-bench.txt index e0aaa3204..d87e897e1 100644 --- a/requirements/requirements-compare-llama-bench.txt +++ b/requirements/requirements-compare-llama-bench.txt @@ -1,2 +1,3 @@ tabulate~=0.9.0 GitPython~=3.1.43 +matplotlib~=3.10.0 diff --git a/requirements/requirements-convert_hf_to_gguf.txt b/requirements/requirements-convert_hf_to_gguf.txt index 8cb9c354f..431c596c1 100644 --- a/requirements/requirements-convert_hf_to_gguf.txt +++ b/requirements/requirements-convert_hf_to_gguf.txt @@ -1,3 +1,7 @@ -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.2.1 +torch~=2.2.1; platform_machine != "s390x" + +# torch s390x packages can only be found from nightly builds +--extra-index-url https://download.pytorch.org/whl/nightly +torch>=0.0.0.dev0; platform_machine == "s390x" diff --git a/requirements/requirements-convert_hf_to_gguf_update.txt b/requirements/requirements-convert_hf_to_gguf_update.txt index 8cb9c354f..431c596c1 100644 --- a/requirements/requirements-convert_hf_to_gguf_update.txt +++ b/requirements/requirements-convert_hf_to_gguf_update.txt @@ -1,3 +1,7 @@ -r ./requirements-convert_legacy_llama.txt --extra-index-url https://download.pytorch.org/whl/cpu -torch~=2.2.1 +torch~=2.2.1; platform_machine != "s390x" + +# torch s390x packages can only be found from nightly builds +--extra-index-url https://download.pytorch.org/whl/nightly +torch>=0.0.0.dev0; platform_machine == "s390x" diff --git a/requirements/requirements-convert_lora_to_gguf.txt b/requirements/requirements-convert_lora_to_gguf.txt index 5758076c4..d091d5648 100644 --- a/requirements/requirements-convert_lora_to_gguf.txt +++ b/requirements/requirements-convert_lora_to_gguf.txt @@ -1,2 +1,4 @@ -r ./requirements-convert_hf_to_gguf.txt --extra-index-url https://download.pytorch.org/whl/cpu +# torch s390x packages can only be found from nightly builds +--extra-index-url https://download.pytorch.org/whl/nightly diff --git a/requirements/requirements-gguf_editor_gui.txt b/requirements/requirements-gguf_editor_gui.txt new file mode 100644 index 000000000..fd253364e --- /dev/null +++ b/requirements/requirements-gguf_editor_gui.txt @@ -0,0 +1,3 @@ +numpy~=1.26.4 +PySide6~=6.9.0 +gguf>=0.17.0 diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh index e40d1cc6d..94a8eceb3 100755 --- a/scripts/compare-commits.sh +++ b/scripts/compare-commits.sh @@ -17,14 +17,14 @@ rm -f llama-bench.sqlite > /dev/null # to test a backend, call the script with the corresponding environment variable (e.g. GGML_CUDA=1 ./scripts/compare-commits.sh ...) if [ -n "$GGML_CUDA" ]; then - cmake_opts="-DGGML_CUDA=ON" + CMAKE_OPTS="${CMAKE_OPTS} -DGGML_CUDA=ON" fi dir="build-bench" function run { rm -fr ${dir} > /dev/null - cmake -B ${dir} -S . $cmake_opts > /dev/null + cmake -B ${dir} -S . ${CMAKE_OPTS} > /dev/null cmake --build ${dir} -t llama-bench > /dev/null ${dir}/bin/llama-bench -o sql -oe md $bench_args | sqlite3 llama-bench.sqlite } diff --git a/scripts/compare-llama-bench.py b/scripts/compare-llama-bench.py index 6205fe88d..30e3cf864 100755 --- a/scripts/compare-llama-bench.py +++ b/scripts/compare-llama-bench.py @@ -7,6 +7,10 @@ import sys import os from glob import glob import sqlite3 +import json +import csv +from typing import Optional, Union +from collections.abc import Iterator, Sequence try: import git @@ -15,13 +19,36 @@ except ImportError as e: print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100 raise e + logger = logging.getLogger("compare-llama-bench") +# All llama-bench SQL fields +DB_FIELDS = [ + "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename", + "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", + "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", + "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides", + "defrag_thold", + "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", + "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", +] + +DB_TYPES = [ + "TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT", + "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT", + "REAL", + "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", + "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", +] +assert len(DB_FIELDS) == len(DB_TYPES) + # Properties by which to differentiate results per commit: KEY_PROPERTIES = [ - "cpu_info", "gpu_info", "backends", "n_gpu_layers", "model_filename", "model_type", "n_batch", "n_ubatch", - "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v", "use_mmap", "no_kv_offload", - "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen" + "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type", + "n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v", + "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth" ] # Properties that are boolean and are converted to Yes/No for the table: @@ -30,11 +57,11 @@ BOOL_PROPERTIES = ["embeddings", "cpu_strict", "use_mmap", "no_kv_offload", "fla # Header names for the table: PRETTY_NAMES = { "cpu_info": "CPU", "gpu_info": "GPU", "backends": "Backends", "n_gpu_layers": "GPU layers", - "model_filename": "File", "model_type": "Model", "model_size": "Model size [GiB]", - "model_n_params": "Num. of par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", - "embeddings": "Embeddings", "cpu_mask": "CPU mask", "cpu_strict": "CPU strict", "poll": "Poll", - "n_threads": "Threads", "type_k": "K type", "type_v": "V type", "split_mode": "Split mode", "main_gpu": "Main GPU", - "no_kv_offload": "NKVO", "flash_attn": "FlashAttention", "tensor_split": "Tensor split", "use_mmap": "Use mmap", + "tensor_buft_overrides": "Tensor overrides", "model_filename": "File", "model_type": "Model", "model_size": "Model size [GiB]", + "model_n_params": "Num. of par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", "embeddings": "Embeddings", + "cpu_mask": "CPU mask", "cpu_strict": "CPU strict", "poll": "Poll", "n_threads": "Threads", "type_k": "K type", "type_v": "V type", + "use_mmap": "Use mmap", "no_kv_offload": "NKVO", "split_mode": "Split mode", "main_gpu": "Main GPU", "tensor_split": "Tensor split", + "flash_attn": "FlashAttention", } DEFAULT_SHOW = ["model_type"] # Always show these properties by default. @@ -42,7 +69,7 @@ DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default. GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables. MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"} -DESCRIPTION = """Creates tables from llama-bench data written to an SQLite database. Example usage (Linux): +DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux): $ git checkout master $ make clean && make llama-bench @@ -70,12 +97,13 @@ help_c = ( ) parser.add_argument("-c", "--compare", help=help_c) help_i = ( - "Input SQLite file for comparing commits. " + "JSON/JSONL/SQLite/CSV files for comparing commits. " + "Specify multiple times to use multiple input files (JSON/CSV only). " "Defaults to 'llama-bench.sqlite' in the current working directory. " "If no such file is found and there is exactly one .sqlite file in the current directory, " "that file is instead used as input." ) -parser.add_argument("-i", "--input", help=help_i) +parser.add_argument("-i", "--input", action="append", help=help_i) help_o = ( "Output format for the table. " "Defaults to 'pipe' (GitHub compatible). " @@ -86,7 +114,7 @@ parser.add_argument("-o", "--output", help=help_o, default="pipe") help_s = ( "Columns to add to the table. " "Accepts a comma-separated list of values. " - f"Legal values: {', '.join(KEY_PROPERTIES[:-2])}. " + f"Legal values: {', '.join(KEY_PROPERTIES[:-3])}. " "Defaults to model name (model_type) and CPU and/or GPU name (cpu_info, gpu_info) " "plus any column where not all data points are the same. " "If the columns are manually specified, then the results for each unique combination of the " @@ -95,11 +123,15 @@ help_s = ( parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed") parser.add_argument("-s", "--show", help=help_s) parser.add_argument("--verbose", action="store_true", help="increase output verbosity") +parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)") +parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth") +parser.add_argument("--plot_log_scale", action="store_true", help="use log scale for x axis in plots (off by default)") known_args, unknown_args = parser.parse_known_args() logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) + if known_args.check: # Check if all required Python libraries are installed. Would have failed earlier if not. sys.exit(0) @@ -110,119 +142,321 @@ if unknown_args: sys.exit(1) input_file = known_args.input -if input_file is None and os.path.exists("./llama-bench.sqlite"): - input_file = "llama-bench.sqlite" -if input_file is None: +if not input_file and os.path.exists("./llama-bench.sqlite"): + input_file = ["llama-bench.sqlite"] +if not input_file: sqlite_files = glob("*.sqlite") if len(sqlite_files) == 1: - input_file = sqlite_files[0] + input_file = sqlite_files -if input_file is None: +if not input_file: logger.error("Cannot find a suitable input file, please provide one.\n") parser.print_help() sys.exit(1) -connection = sqlite3.connect(input_file) -cursor = connection.cursor() -build_len_min: int = cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] -build_len_max: int = cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] +class LlamaBenchData: + repo: Optional[git.Repo] + build_len_min: int + build_len_max: int + build_len: int = 8 + builds: list[str] = [] + check_keys = set(KEY_PROPERTIES + ["build_commit", "test_time", "avg_ts"]) -if build_len_min != build_len_max: - logger.warning(f"{input_file} contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " - "Try purging the the database of old commits.") - cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {build_len_min});") + def __init__(self): + try: + self.repo = git.Repo(".", search_parent_directories=True) + except git.InvalidGitRepositoryError: + self.repo = None -build_len: int = build_len_min + def _builds_init(self): + self.build_len = self.build_len_min -builds = cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() -builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] - -if not builds: - raise RuntimeError(f"{input_file} does not contain any builds.") - -try: - repo = git.Repo(".", search_parent_directories=True) -except git.InvalidGitRepositoryError: - repo = None - - -def find_parent_in_data(commit: git.Commit): - """Helper function to find the most recent parent measured in number of commits for which there is data.""" - heap: list[tuple[int, git.Commit]] = [(0, commit)] - seen_hexsha8 = set() - while heap: - depth, current_commit = heapq.heappop(heap) - current_hexsha8 = commit.hexsha[:build_len] - if current_hexsha8 in builds: - return current_hexsha8 - for parent in commit.parents: - parent_hexsha8 = parent.hexsha[:build_len] - if parent_hexsha8 not in seen_hexsha8: - seen_hexsha8.add(parent_hexsha8) - heapq.heappush(heap, (depth + 1, parent)) - return None - - -def get_all_parent_hexsha8s(commit: git.Commit): - """Helper function to recursively get hexsha8 values for all parents of a commit.""" - unvisited = [commit] - visited = [] - - while unvisited: - current_commit = unvisited.pop(0) - visited.append(current_commit.hexsha[:build_len]) - for parent in current_commit.parents: - if parent.hexsha[:build_len] not in visited: - unvisited.append(parent) - - return visited - - -def get_commit_name(hexsha8: str): - """Helper function to find a human-readable name for a commit if possible.""" - if repo is None: - return hexsha8 - for h in repo.heads: - if h.commit.hexsha[:build_len] == hexsha8: - return h.name - for t in repo.tags: - if t.commit.hexsha[:build_len] == hexsha8: - return t.name - return hexsha8 - - -def get_commit_hexsha8(name: str): - """Helper function to search for a commit given a human-readable name.""" - if repo is None: + def _check_keys(self, keys: set) -> Optional[set]: + """Private helper method that checks against required data keys and returns missing ones.""" + if not keys >= self.check_keys: + return self.check_keys - keys return None - for h in repo.heads: - if h.name == name: - return h.commit.hexsha[:build_len] - for t in repo.tags: - if t.name == name: - return t.commit.hexsha[:build_len] - for c in repo.iter_commits("--all"): - if c.hexsha[:build_len] == name[:build_len]: - return c.hexsha[:build_len] - return None + + def find_parent_in_data(self, commit: git.Commit) -> Optional[str]: + """Helper method to find the most recent parent measured in number of commits for which there is data.""" + heap: list[tuple[int, git.Commit]] = [(0, commit)] + seen_hexsha8 = set() + while heap: + depth, current_commit = heapq.heappop(heap) + current_hexsha8 = commit.hexsha[:self.build_len] + if current_hexsha8 in self.builds: + return current_hexsha8 + for parent in commit.parents: + parent_hexsha8 = parent.hexsha[:self.build_len] + if parent_hexsha8 not in seen_hexsha8: + seen_hexsha8.add(parent_hexsha8) + heapq.heappush(heap, (depth + 1, parent)) + return None + + def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]: + """Helper method to recursively get hexsha8 values for all parents of a commit.""" + unvisited = [commit] + visited = [] + + while unvisited: + current_commit = unvisited.pop(0) + visited.append(current_commit.hexsha[:self.build_len]) + for parent in current_commit.parents: + if parent.hexsha[:self.build_len] not in visited: + unvisited.append(parent) + + return visited + + def get_commit_name(self, hexsha8: str) -> str: + """Helper method to find a human-readable name for a commit if possible.""" + if self.repo is None: + return hexsha8 + for h in self.repo.heads: + if h.commit.hexsha[:self.build_len] == hexsha8: + return h.name + for t in self.repo.tags: + if t.commit.hexsha[:self.build_len] == hexsha8: + return t.name + return hexsha8 + + def get_commit_hexsha8(self, name: str) -> Optional[str]: + """Helper method to search for a commit given a human-readable name.""" + if self.repo is None: + return None + for h in self.repo.heads: + if h.name == name: + return h.commit.hexsha[:self.build_len] + for t in self.repo.tags: + if t.name == name: + return t.commit.hexsha[:self.build_len] + for c in self.repo.iter_commits("--all"): + if c.hexsha[:self.build_len] == name[:self.build_len]: + return c.hexsha[:self.build_len] + return None + + def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + """Helper method that gets rows of (build_commit, test_time) sorted by the latter.""" + return [] + + def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + """ + Helper method that gets table rows for some list of properties. + Rows are created by combining those where all provided properties are equal. + The resulting rows are then grouped by the provided properties and the t/s values are averaged. + The returned rows are unique in terms of property combinations. + """ + return [] + + +class LlamaBenchDataSQLite3(LlamaBenchData): + connection: sqlite3.Connection + cursor: sqlite3.Cursor + + def __init__(self): + super().__init__() + self.connection = sqlite3.connect(":memory:") + self.cursor = self.connection.cursor() + self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});") + + def _builds_init(self): + if self.connection: + self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] + self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] + + if self.build_len_min != self.build_len_max: + logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " + "Try purging the the database of old commits.") + self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});") + + builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() + self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] + super()._builds_init() + + def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + data = self.cursor.execute( + "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall() + return reversed(data) if reverse else data + + def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + select_string = ", ".join( + [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"]) + equal_string = " AND ".join( + [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [ + f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"] + ) + group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"]) + query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " + f"GROUP BY {group_order_string} ORDER BY {group_order_string};") + return self.cursor.execute(query).fetchall() + + +class LlamaBenchDataSQLite3File(LlamaBenchDataSQLite3): + def __init__(self, data_file: str): + super().__init__() + + self.connection.close() + self.connection = sqlite3.connect(data_file) + self.cursor = self.connection.cursor() + self._builds_init() + + @staticmethod + def valid_format(data_file: str) -> bool: + connection = sqlite3.connect(data_file) + cursor = connection.cursor() + + try: + if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0: + raise sqlite3.DatabaseError("The provided input file does not exist or is empty.") + except sqlite3.DatabaseError as e: + logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e) + cursor = None + + connection.close() + return True if cursor else False + + +class LlamaBenchDataJSONL(LlamaBenchDataSQLite3): + def __init__(self, data_file: str): + super().__init__() + + with open(data_file, "r", encoding="utf-8") as fp: + for i, line in enumerate(fp): + parsed = json.loads(line) + + for k in parsed.keys() - set(DB_FIELDS): + del parsed[k] + + if (missing_keys := self._check_keys(parsed.keys())): + raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_file: str) -> bool: + try: + with open(data_file, "r", encoding="utf-8") as fp: + for line in fp: + json.loads(line) + break + except Exception as e: + logger.debug(f'"{data_file}" is not a valid JSONL file.', exc_info=e) + return False + + return True + + +class LlamaBenchDataJSON(LlamaBenchDataSQLite3): + def __init__(self, data_files: list[str]): + super().__init__() + + for data_file in data_files: + with open(data_file, "r", encoding="utf-8") as fp: + parsed = json.load(fp) + + for i, entry in enumerate(parsed): + for k in entry.keys() - set(DB_FIELDS): + del entry[k] + + if (missing_keys := self._check_keys(entry.keys())): + raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_files: list[str]) -> bool: + if not data_files: + return False + + for data_file in data_files: + try: + with open(data_file, "r", encoding="utf-8") as fp: + json.load(fp) + except Exception as e: + logger.debug(f'"{data_file}" is not a valid JSON file.', exc_info=e) + return False + + return True + + +class LlamaBenchDataCSV(LlamaBenchDataSQLite3): + def __init__(self, data_files: list[str]): + super().__init__() + + for data_file in data_files: + with open(data_file, "r", encoding="utf-8") as fp: + for i, parsed in enumerate(csv.DictReader(fp)): + keys = set(parsed.keys()) + + for k in keys - set(DB_FIELDS): + del parsed[k] + + if (missing_keys := self._check_keys(keys)): + raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + + self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + + self._builds_init() + + @staticmethod + def valid_format(data_files: list[str]) -> bool: + if not data_files: + return False + + for data_file in data_files: + try: + with open(data_file, "r", encoding="utf-8") as fp: + for parsed in csv.DictReader(fp): + break + except Exception as e: + logger.debug(f'"{data_file}" is not a valid CSV file.', exc_info=e) + return False + + return True + + +bench_data = None +if len(input_file) == 1: + if LlamaBenchDataSQLite3File.valid_format(input_file[0]): + bench_data = LlamaBenchDataSQLite3File(input_file[0]) + elif LlamaBenchDataJSON.valid_format(input_file): + bench_data = LlamaBenchDataJSON(input_file) + elif LlamaBenchDataJSONL.valid_format(input_file[0]): + bench_data = LlamaBenchDataJSONL(input_file[0]) + elif LlamaBenchDataCSV.valid_format(input_file): + bench_data = LlamaBenchDataCSV(input_file) +else: + if LlamaBenchDataJSON.valid_format(input_file): + bench_data = LlamaBenchDataJSON(input_file) + elif LlamaBenchDataCSV.valid_format(input_file): + bench_data = LlamaBenchDataCSV(input_file) + +if not bench_data: + raise RuntimeError("No valid (or some invalid) input files found.") + +if not bench_data.builds: + raise RuntimeError(f"{input_file} does not contain any builds.") hexsha8_baseline = name_baseline = None # If the user specified a baseline, try to find a commit for it: if known_args.baseline is not None: - if known_args.baseline in builds: + if known_args.baseline in bench_data.builds: hexsha8_baseline = known_args.baseline if hexsha8_baseline is None: - hexsha8_baseline = get_commit_hexsha8(known_args.baseline) + hexsha8_baseline = bench_data.get_commit_hexsha8(known_args.baseline) name_baseline = known_args.baseline if hexsha8_baseline is None: logger.error(f"cannot find data for baseline={known_args.baseline}.") sys.exit(1) # Otherwise, search for the most recent parent of master for which there is data: -elif repo is not None: - hexsha8_baseline = find_parent_in_data(repo.heads.master.commit) +elif bench_data.repo is not None: + hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit) if hexsha8_baseline is None: logger.error("No baseline was provided and did not find data for any master branch commits.\n") @@ -235,27 +469,25 @@ else: sys.exit(1) -name_baseline = get_commit_name(hexsha8_baseline) +name_baseline = bench_data.get_commit_name(hexsha8_baseline) hexsha8_compare = name_compare = None # If the user has specified a compare value, try to find a corresponding commit: if known_args.compare is not None: - if known_args.compare in builds: + if known_args.compare in bench_data.builds: hexsha8_compare = known_args.compare if hexsha8_compare is None: - hexsha8_compare = get_commit_hexsha8(known_args.compare) + hexsha8_compare = bench_data.get_commit_hexsha8(known_args.compare) name_compare = known_args.compare if hexsha8_compare is None: logger.error(f"cannot find data for compare={known_args.compare}.") sys.exit(1) # Otherwise, search for the commit for llama-bench was most recently run # and that is not a parent of master: -elif repo is not None: - hexsha8s_master = get_all_parent_hexsha8s(repo.heads.master.commit) - builds_timestamp = cursor.execute( - "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall() - for (hexsha8, _) in reversed(builds_timestamp): +elif bench_data.repo is not None: + hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit) + for (hexsha8, _) in bench_data.builds_timestamp(reverse=True): if hexsha8 not in hexsha8s_master: hexsha8_compare = hexsha8 break @@ -270,46 +502,26 @@ else: parser.print_help() sys.exit(1) -name_compare = get_commit_name(hexsha8_compare) - - -def get_rows(properties): - """ - Helper function that gets table rows for some list of properties. - Rows are created by combining those where all provided properties are equal. - The resulting rows are then grouped by the provided properties and the t/s values are averaged. - The returned rows are unique in terms of property combinations. - """ - select_string = ", ".join( - [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"]) - equal_string = " AND ".join( - [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [ - f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"] - ) - group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt"]) - query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " - f"GROUP BY {group_order_string} ORDER BY {group_order_string};") - return cursor.execute(query).fetchall() - +name_compare = bench_data.get_commit_name(hexsha8_compare) # If the user provided columns to group the results by, use them: if known_args.show is not None: show = known_args.show.split(",") unknown_cols = [] for prop in show: - if prop not in KEY_PROPERTIES[:-2]: # Last two values are n_prompt, n_gen. + if prop not in KEY_PROPERTIES[:-3]: # Last three values are n_prompt, n_gen, n_depth. unknown_cols.append(prop) if unknown_cols: logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}") parser.print_usage() sys.exit(1) - rows_show = get_rows(show) + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) # Otherwise, select those columns where the values are not all the same: else: - rows_full = get_rows(KEY_PROPERTIES) + rows_full = bench_data.get_rows(KEY_PROPERTIES, hexsha8_baseline, hexsha8_compare) properties_different = [] for i, kp_i in enumerate(KEY_PROPERTIES): - if kp_i in DEFAULT_SHOW or kp_i == "n_prompt" or kp_i == "n_gen": + if kp_i in DEFAULT_SHOW or kp_i in ["n_prompt", "n_gen", "n_depth"]: continue for row_full in rows_full: if row_full[i] != rows_full[0][i]: @@ -318,7 +530,7 @@ else: show = [] # Show CPU and/or GPU by default even if the hardware for all results is the same: - if "n_gpu_layers" not in properties_different: + if rows_full and "n_gpu_layers" not in properties_different: ngl = int(rows_full[0][KEY_PROPERTIES.index("n_gpu_layers")]) if ngl != 99 and "cpu_info" not in properties_different: @@ -336,21 +548,36 @@ else: show.remove(prop) except ValueError: pass - rows_show = get_rows(show) + + # Add plot_x parameter to parameters to show if it's not already present: + if known_args.plot: + for k, v in PRETTY_NAMES.items(): + if v == known_args.plot_x and k not in show: + show.append(k) + break + + rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) + +if not rows_show: + logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n") + sys.exit(1) table = [] for row in rows_show: - n_prompt = int(row[-4]) - n_gen = int(row[-3]) + n_prompt = int(row[-5]) + n_gen = int(row[-4]) + n_depth = int(row[-3]) if n_prompt != 0 and n_gen == 0: test_name = f"pp{n_prompt}" elif n_prompt == 0 and n_gen != 0: test_name = f"tg{n_gen}" else: test_name = f"pp{n_prompt}+tg{n_gen}" + if n_depth != 0: + test_name = f"{test_name}@d{n_depth}" # Regular columns test name avg t/s values Speedup # VVVVVVVVVVVVV VVVVVVVVV VVVVVVVVVVVVVV VVVVVVV - table.append(list(row[:-4]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])]) + table.append(list(row[:-5]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])]) # Some a-posteriori fixes to make the table contents prettier: for bool_property in BOOL_PROPERTIES: @@ -376,7 +603,7 @@ if "gpu_info" in show: for gns in GPU_NAME_STRIP: row_table[ip] = row_table[ip].replace(gns, "") - gpu_names = row_table[ip].split("/") + gpu_names = row_table[ip].split(", ") num_gpus = len(gpu_names) all_names_the_same = len(set(gpu_names)) == 1 if len(gpu_names) >= 2 and all_names_the_same: @@ -385,6 +612,161 @@ if "gpu_info" in show: headers = [PRETTY_NAMES[p] for p in show] headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] +if known_args.plot: + def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False): + try: + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') + except ImportError as e: + logger.error("matplotlib is required for --plot.") + raise e + + data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup) + plot_x_index = None + plot_x_label = plot_x_param + + if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]: + pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param) + if pretty_name in data_headers: + plot_x_index = data_headers.index(pretty_name) + plot_x_label = pretty_name + elif plot_x_param in data_headers: + plot_x_index = data_headers.index(plot_x_param) + plot_x_label = plot_x_param + else: + logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}") + return + + grouped_data = {} + + for i, row in enumerate(table_data): + group_key_parts = [] + test_name = row[-4] + + base_test = "" + x_value = None + + if plot_x_param in ["n_prompt", "n_gen", "n_depth"]: + for j, val in enumerate(row[:-4]): + header_name = data_headers[j] + if val is not None and str(val).strip(): + group_key_parts.append(f"{header_name}={val}") + + if plot_x_param == "n_prompt" and "pp" in test_name: + base_test = test_name.split("@")[0] + x_value = base_test + elif plot_x_param == "n_gen" and "tg" in test_name: + x_value = test_name.split("@")[0] + elif plot_x_param == "n_depth" and "@d" in test_name: + base_test = test_name.split("@d")[0] + x_value = int(test_name.split("@d")[1]) + else: + base_test = test_name + + if base_test.strip(): + group_key_parts.append(f"Test={base_test}") + else: + for j, val in enumerate(row[:-4]): + if j != plot_x_index: + header_name = data_headers[j] + if val is not None and str(val).strip(): + group_key_parts.append(f"{header_name}={val}") + else: + x_value = val + + group_key_parts.append(f"Test={test_name}") + + group_key = tuple(group_key_parts) + + if group_key not in grouped_data: + grouped_data[group_key] = [] + + grouped_data[group_key].append({ + 'x_value': x_value, + 'baseline': float(row[-3]), + 'compare': float(row[-2]), + 'speedup': float(row[-1]) + }) + + if not grouped_data: + logger.error("No data available for plotting") + return + + def make_axes(num_groups, max_cols=2, base_size=(8, 4)): + from math import ceil + cols = 1 if num_groups == 1 else min(max_cols, num_groups) + rows = ceil(num_groups / cols) + + # Scale figure size by grid dimensions + w, h = base_size + fig, ax_arr = plt.subplots(rows, cols, + figsize=(w * cols, h * rows), + squeeze=False) + + axes = ax_arr.flatten()[:num_groups] + return fig, axes + + num_groups = len(grouped_data) + fig, axes = make_axes(num_groups) + + plot_idx = 0 + + for group_key, points in grouped_data.items(): + if plot_idx >= len(axes): + break + ax = axes[plot_idx] + + try: + points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0) + x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted] + except ValueError: + points_sorted = sorted(points, key=lambda p: group_key) + x_values = [p['x_value'] for p in points_sorted] + + baseline_vals = [p['baseline'] for p in points_sorted] + compare_vals = [p['compare'] for p in points_sorted] + + ax.plot(x_values, baseline_vals, 'o-', color='skyblue', + label=f'{baseline_name}', linewidth=2, markersize=6) + ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, + label=f'{compare_name}', linewidth=2, markersize=6) + + if log_scale: + ax.set_xscale('log', base=2) + unique_x = sorted(set(x_values)) + ax.set_xticks(unique_x) + ax.set_xticklabels([str(int(x)) for x in unique_x]) + + title_parts = [] + for part in group_key: + if '=' in part: + key, value = part.split('=', 1) + title_parts.append(f"{key}: {value}") + + title = ', '.join(title_parts) if title_parts else "Performance comparison" + + ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold') + ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold') + ax.set_title(title, fontsize=12, fontweight='bold') + ax.legend(loc='best', fontsize=10) + ax.grid(True, alpha=0.3) + + plot_idx += 1 + + for i in range(plot_idx, len(axes)): + axes[i].set_visible(False) + + fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}', + fontsize=14, fontweight='bold') + fig.subplots_adjust(top=1) + + plt.tight_layout() + plt.savefig(output_file, dpi=300, bbox_inches='tight') + plt.close() + + create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale) + print(tabulate( # noqa: NP100 table, headers=headers, diff --git a/scripts/fetch_server_test_models.py b/scripts/fetch_server_test_models.py index e6775bfc5..ac483ef5d 100755 --- a/scripts/fetch_server_test_models.py +++ b/scripts/fetch_server_test_models.py @@ -8,7 +8,7 @@ Example: python scripts/fetch_server_test_models.py - ( cd examples/server/tests && ./tests.sh -v -x -m slow ) + ( cd tools/server/tests && ./tests.sh -v -x -m slow ) ''' import ast import glob @@ -66,7 +66,7 @@ if __name__ == '__main__': models = sorted(list(set([ model - for test_file in glob.glob('examples/server/tests/unit/test_*.py') + for test_file in glob.glob('tools/server/tests/unit/test_*.py') for model in collect_hf_model_test_parameters(test_file) ])), key=lambda m: (m.hf_repo, m.hf_file)) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 7111936ba..bb5d56a0e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -2abf606f098844faebee578996cae9c6d63a40e2 +8cda0a3c19f2c7dc493887353c42f6956bc268b1 diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py new file mode 100755 index 000000000..1151c9f01 --- /dev/null +++ b/scripts/sync_vendor.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 + +import urllib.request + +vendor = { + "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", + "https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp", + + # sync manually + # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/minja.hpp": "vendor/minja/minja.hpp", + # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/chat-template.hpp": "vendor/minja/chat-template.hpp", + + "https://raw.githubusercontent.com/nothings/stb/refs/heads/master/stb_image.h": "vendor/stb/stb_image.h", + + "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.22/miniaudio.h": "vendor/miniaudio/miniaudio.h", + + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.20.1/httplib.h": "vendor/cpp-httplib/httplib.h", +} + +for url, filename in vendor.items(): + print(f"downloading {url} to {filename}") # noqa: NP100 + urllib.request.urlretrieve(url, filename) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 0f406bc42..d8018e2e2 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -2,7 +2,7 @@ ''' Simplistic tool call benchmarks for llama-server and ollama. - Essentially runs the tests at server/examples/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama), + Essentially runs the tests at server/tools/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama), and plots the results of multiple runs (from same .jsonl file or multiple ones) as a success rate heatmap. Simple usage example: @@ -12,6 +12,7 @@ export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp} + ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b @@ -51,8 +52,8 @@ import typer sys.path.insert(0, Path(__file__).parent.parent.as_posix()) if True: - from examples.server.tests.utils import ServerProcess - from examples.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather + from tools.server.tests.utils import ServerProcess + from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather @contextmanager @@ -205,6 +206,7 @@ def run( model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None, hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None, chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None, + chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None, ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None, llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None, n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10, @@ -229,6 +231,12 @@ def run( # n_ctx = 8192 n_ctx = 2048 + if model is None: + if hf is not None: + model = hf.split("/")[-1] + elif ollama is not None: + model = ollama + assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" with output.open('a' if append else 'w') as output_file: @@ -320,6 +328,7 @@ def run( server.model_hf_repo = hf server.model_hf_file = None server.chat_template = chat_template + server.chat_template_file = chat_template_file server.server_path = server_path if port is not None: server.server_port = port @@ -335,6 +344,7 @@ def run( temp=t, output_kwargs=dict( chat_template=chat_template, + chat_template_file=chat_template_file, ), request_kwargs=dict( ignore_chat_grammar=ignore_chat_grammar, @@ -355,6 +365,7 @@ def run( temp=t, output_kwargs=dict( chat_template=None, + chat_template_file=None, ), request_kwargs=dict( model=ollama, diff --git a/scripts/xxd.cmake b/scripts/xxd.cmake index f5ad6ab9b..14d275380 100644 --- a/scripts/xxd.cmake +++ b/scripts/xxd.cmake @@ -1,5 +1,5 @@ # CMake equivalent of `xxd -i ${INPUT} ${OUTPUT}` -# Usage: cmake -DINPUT=examples/server/public/index.html -DOUTPUT=examples/server/index.html.hpp -P scripts/xxd.cmake +# Usage: cmake -DINPUT=tools/server/public/index.html -DOUTPUT=tools/server/index.html.hpp -P scripts/xxd.cmake SET(INPUT "" CACHE STRING "Input File") SET(OUTPUT "" CACHE STRING "Output File") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9f7ab13f1..70be604e4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -14,15 +14,19 @@ add_library(llama llama-batch.cpp llama-chat.cpp llama-context.cpp + llama-cparams.cpp llama-grammar.cpp llama-graph.cpp llama-hparams.cpp llama-impl.cpp llama-io.cpp - llama-kv-cache.cpp + llama-kv-cache-unified.cpp + llama-kv-cache-unified-iswa.cpp + llama-kv-cache-recurrent.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp + llama-model-saver.cpp llama-model.cpp llama-quant.cpp llama-sampling.cpp @@ -32,8 +36,9 @@ add_library(llama unicode.h ) -target_include_directories(llama PUBLIC . ../include) -target_compile_features (llama PUBLIC cxx_std_17) # don't bump +target_include_directories(llama PRIVATE .) +target_include_directories(llama PUBLIC ../include) +target_compile_features (llama PRIVATE cxx_std_17) # don't bump target_link_libraries(llama PUBLIC ggml) diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp index 7ac54d239..8d94034ae 100644 --- a/src/llama-adapter.cpp +++ b/src/llama-adapter.cpp @@ -253,6 +253,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ std::vector buft_extra; { auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) @@ -291,6 +294,9 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } buft = ggml_backend_dev_buffer_type(cpu_dev); break; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a6fddc7fd..de8d289cf 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -19,6 +19,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_REFACT, "refact" }, { LLM_ARCH_BERT, "bert" }, { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" }, + { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, @@ -71,6 +73,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_DOTS1, "dots1" }, + { LLM_ARCH_ARCEE, "arcee" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -106,6 +110,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, + { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -140,6 +145,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -170,6 +177,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" }, { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" }, + { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -194,7 +203,6 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, - { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -238,6 +246,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_ARCEE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_LLAMA4, { @@ -444,6 +470,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, @@ -470,6 +497,39 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_NOMIC_BERT_MOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_TOKEN_TYPES, "token_types" }, + { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_NEO_BERT, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, + { LLM_TENSOR_CLS, "cls" }, + { LLM_TENSOR_CLS_OUT, "cls.output" }, + }, + }, { LLM_ARCH_JINA_BERT_V2, { @@ -1103,6 +1163,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, { LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" }, + { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, + { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, @@ -1457,6 +1519,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, { @@ -1526,6 +1591,34 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, }, }, + { + LLM_ARCH_DOTS1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, { LLM_ARCH_UNKNOWN, { @@ -1563,23 +1656,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_DOWN_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_GATE_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_FFN_UP_SHEXP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_A_MQA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -1692,8 +1770,14 @@ static const std::map LLM_TENSOR_INFOS = { LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { - return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) - : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + std::string name = ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + + if (suffix != nullptr) { + name += "."; + name += suffix; + } + + return name; } std::string LLM_TN_IMPL::str() const { diff --git a/src/llama-arch.h b/src/llama-arch.h index 2c2099b3c..3e8a61da3 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -23,6 +23,8 @@ enum llm_arch { LLM_ARCH_REFACT, LLM_ARCH_BERT, LLM_ARCH_NOMIC_BERT, + LLM_ARCH_NOMIC_BERT_MOE, + LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, @@ -75,6 +77,8 @@ enum llm_arch { LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, + LLM_ARCH_DOTS1, + LLM_ARCH_ARCEE, LLM_ARCH_UNKNOWN, }; @@ -110,6 +114,7 @@ enum llm_kv { LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, + LLM_KV_MOE_EVERY_N_LAYERS, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -144,6 +149,8 @@ enum llm_kv { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SCALE, + LLM_KV_ATTENTION_KEY_LENGTH_MLA, + LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -192,7 +199,6 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, - LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -209,6 +215,8 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_CLASSIFIER_OUTPUT_LABELS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -306,6 +314,8 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_B, LLM_TENSOR_ATTN_KV_A_MQA, LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, LLM_TENSOR_ATTN_Q_A_NORM, LLM_TENSOR_ATTN_KV_A_NORM, LLM_TENSOR_ATTN_SUB_NORM, diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57f..8b6d14fe8 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -1,7 +1,14 @@ #include "llama-batch.h" +#include "llama-impl.h" +#include "llama-cparams.h" +#include "llama-vocab.h" +#include "llama-memory.h" + +#include #include #include +#include llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { // clear empty sequences @@ -14,24 +21,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { break; } } - ubatch_token.resize(!has_embd ? n_ubatch : 0); - ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0); - ubatch_pos.resize(n_ubatch); - ubatch_n_seq_id.resize(n_ubatch); - ubatch_seq_id.resize(n_ubatch); - ubatch_output.resize(n_ubatch); + + udatas.push_back({}); + + auto & udata = udatas.back(); + + udata.token.resize(!has_embd ? n_ubatch : 0); + udata.embd.resize(has_embd ? n_embd * n_ubatch : 0); + udata.pos.resize(n_ubatch); + udata.n_seq_id.resize(n_ubatch); + udata.seq_id.resize(n_ubatch); + udata.output.resize(n_ubatch); + llama_ubatch ubatch = { /*equal_seqs =*/ true, /*n_tokens =*/ 0, /*n_seq_tokens =*/ 0, /*n_seqs =*/ 0, - /*token =*/ !has_embd ? ubatch_token.data() : nullptr, - /*embd =*/ has_embd ? ubatch_embd.data() : nullptr, - /*pos =*/ ubatch_pos.data(), - /*n_seq_id =*/ ubatch_n_seq_id.data(), - /*seq_id =*/ ubatch_seq_id.data(), - /*output =*/ ubatch_output.data(), + /*token =*/ !has_embd ? udata.token.data() : nullptr, + /*embd =*/ has_embd ? udata.embd.data() : nullptr, + /*pos =*/ udata.pos.data(), + /*n_seq_id =*/ udata.n_seq_id.data(), + /*seq_id =*/ udata.seq_id.data(), + /*output =*/ udata.output.data(), }; + return ubatch; } @@ -97,12 +111,7 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s ubatch.seq_id = batch->seq_id + seq.offset; } } - if (logits_all) { - for (size_t i = 0; i < length; ++i) { - ubatch.output[ubatch.n_tokens + i] = 1; - out_ids.push_back(ids[seq.offset + i]); - } - } else if (batch->logits) { + if (batch->logits) { if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { size_t id = ids[seq.offset + i]; @@ -189,11 +198,10 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { return ubatch; } -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { +llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; - this->logits_all = logits_all; n_tokens = batch.n_tokens; ids.resize(n_tokens); @@ -203,6 +211,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim for (size_t i = 0; i < n_tokens; ++i) { ids[i] = i; } + if (simple_split) { seq.resize(1); llama_sbatch_seq & s = seq[0]; @@ -212,6 +221,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim s.length = n_tokens; return; } + std::sort(ids.begin(), ids.end(), [&batch](size_t a, size_t b) { int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; @@ -239,6 +249,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim return n_seq_a > n_seq_b; } ); + // init seq llama_sbatch_seq * last_seq = nullptr; @@ -262,6 +273,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim seq.push_back(new_seq); last_seq = &seq.back(); } + // keep shared prompts first at the end, then sort by length descending. std::sort(seq.begin(), seq.end(), [](llama_sbatch_seq & a, llama_sbatch_seq & b) { @@ -273,16 +285,56 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ); } -llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { - batch = in_batch; - GGML_ASSERT(batch.n_tokens > 0); - if (!batch.pos) { - pos.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { - pos[i] = i + p0; - } - batch.pos = pos.data(); +llama_batch_allocr::llama_batch_allocr() { + const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG"); + debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0; + + seq_pos.resize(LLAMA_MAX_SEQ); + seq_cpl.resize(LLAMA_MAX_SEQ); + for (auto & cur : seq_cpl) { + cur.resize(LLAMA_MAX_SEQ); } +} + +bool llama_batch_allocr::init( + const llama_batch & batch_inp, + const llama_vocab & vocab, + const llama_memory_i * memory, + bool embd_all) { + clear(); + + batch = batch_inp; + + GGML_ASSERT(batch.n_tokens > 0); + + // + // validate input batch + // + + if (batch.token) { + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + return false; + } + } + } + + if (batch.seq_id) { + for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ); + return false; + } + } + } + } + + // + // auto-generate missing fields + // + if (!batch.n_seq_id) { n_seq_id.resize(batch.n_tokens); for (int32_t i = 0; i < batch.n_tokens; i++) { @@ -290,6 +342,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 } batch.n_seq_id = n_seq_id.data(); } + if (!batch.seq_id) { seq_id.resize(batch.n_tokens + 1); seq_id[batch.n_tokens] = NULL; @@ -298,10 +351,221 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 } batch.seq_id = seq_id.data(); } + + if (!batch.pos) { + pos.resize(batch.n_tokens); + + // initialize the starting position for each sequence based on the positions in the memory + llama_pos p0[LLAMA_MAX_SEQ]; + for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (!memory) { + p0[s] = 0; + } else { + p0[s] = memory->seq_pos_max(s) + 1; + } + } + + for (int32_t i = 0; i < batch.n_tokens; i++) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + + pos[i] = p0[seq_id]; + + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + p0[batch.seq_id[i][s]] = pos[i] + 1; + } + } + + batch.pos = pos.data(); + } + if (!batch.logits) { - logits.resize(batch.n_tokens); - logits[logits.size() - 1] = true; - batch.logits = logits.data(); + if (embd_all) { + // return the output for all tokens + output.resize(batch.n_tokens, true); + } else { + // return the output only for the last token + output.resize(batch.n_tokens, false); + output[output.size() - 1] = true; + } + + batch.logits = output.data(); + } else if (embd_all) { + bool warn = false; + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.logits[i] == 0) { + warn = true; + } + } + + if (warn) { + LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__); + + output.resize(batch.n_tokens, true); + batch.logits = output.data(); + } + } + + // + // compute stats + // + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + n_outputs += batch.logits[i] != 0; + } + + // determine coupled sequences + // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them + for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + seq_pos[batch.seq_id[i][s]].insert(batch.pos[i]); + + if (s > 0) { + const llama_seq_id s0 = batch.seq_id[i][0]; + const llama_seq_id s1 = batch.seq_id[i][s]; + + // mark that sequence s1 is coupled to s0 + seq_cpl[s1][s0] = true; + + // note: the other way around is not necessary for now + //seq_cpl[s0][s1] = true; + } + } + } + + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__); + LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, batch.n_tokens); + LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token); + LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd); + LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos); + LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id); + LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id); + LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits); + LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs); + + if (debug > 1) { + int seq_id_max = 0; + for (int32_t i = 0; i < batch.n_tokens; ++i) { + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + seq_id_max = std::max(seq_id_max, batch.seq_id[i][s]); + } + } + } + ++seq_id_max; + + LLAMA_LOG_DEBUG("%s: token = [\n", __func__); + for (int32_t i = 0; i < batch.n_tokens; ++i) { + std::vector seq_id(seq_id_max); + + for (int s = 0; s < batch.n_seq_id[i]; ++s) { + seq_id[batch.seq_id[i][s]] = 1; + } + + std::stringstream ss; + for (int s = 0; s < seq_id_max; ++s) { + if (seq_id[s]) { + ss << s%10; + } else { + ss << "."; + } + } + + LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", + __func__, i, batch.token[i], vocab.token_to_piece(batch.token[i]).c_str(), + batch.pos[i], batch.n_seq_id[i], ss.str().c_str(), batch.logits[i]); + } + LLAMA_LOG_DEBUG("%s: ]\n", __func__); + + LLAMA_LOG_DEBUG("%s: seq = [\n", __func__); + for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) { + if (seq_pos[s0].empty()) { + continue; + } + + std::stringstream ss; + for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) { + if (seq_cpl[s0][s1]) { + ss << s1 << " "; + } + } + + LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n", + __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str()); + } + LLAMA_LOG_DEBUG("%s: ]\n", __func__); + } + } + + // + // consistency checks + // + + for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos[s].empty()) { + continue; + } + + if (memory && seq_pos_min(s) != memory->seq_pos_max(s) + 1) { + LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s); + return false; + } + + if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) { + LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s); + return false; + } + } + + if (memory) { + for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) { + for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) { + if (seq_cpl[s0][s1]) { + if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) || + memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) { + LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1); + return false; + } + } + } + } + } + + return true; +} + +const llama_batch & llama_batch_allocr::get_batch() const { + return batch; +} + +uint32_t llama_batch_allocr::get_n_outputs() const { + return n_outputs; +} + +llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const { + return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin(); +} + +llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const { + return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin(); +} + +void llama_batch_allocr::clear() { + n_outputs = 0; + + batch = {}; + pos.clear(); + n_seq_id.clear(); + seq_id.clear(); + output.clear(); + + for (auto & cur : seq_pos) { + cur.clear(); + } + + for (auto & cur : seq_cpl) { + std::fill(cur.begin(), cur.end(), false); } } diff --git a/src/llama-batch.h b/src/llama-batch.h index f1df40d27..a555c1572 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -4,6 +4,7 @@ #include #include +#include // very similar to llama_batch, // but has more metadata about sequences @@ -11,7 +12,7 @@ struct llama_ubatch { bool equal_seqs; // TODO: whole_seqs for embeddings? - uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) + uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) uint32_t n_seq_tokens; // tokens per sequence uint32_t n_seqs; @@ -39,8 +40,6 @@ struct llama_sbatch { size_t n_embd; - bool logits_all; // TODO: remove once lctx.logits_all is removed too - // sorted indices into the batch std::vector ids; // batch indices of the output @@ -49,13 +48,18 @@ struct llama_sbatch { const llama_batch * batch = nullptr; - // buffers for the ubatch - std::vector ubatch_token; - std::vector ubatch_embd; - std::vector ubatch_pos; - std::vector ubatch_n_seq_id; - std::vector ubatch_seq_id; - std::vector ubatch_output; + // buffers for the ubatches + // TODO: very hacky, this needs a complete rework + struct ubatch_data { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector output; + }; + + std::vector udatas; llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false); @@ -70,19 +74,46 @@ struct llama_sbatch { // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); - void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); + llama_sbatch() = default; + llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false); }; -// temporary allocate memory for the input batch if needed -struct llama_batch_allocr { - struct llama_batch batch; +// a helper for sanitizing and fulfilling a batch +class llama_batch_allocr { +public: + llama_batch_allocr(); + + // sanitize and auto-gen missing data in the input batch + // memory is optional. if provided will be used to check for sequence continuity and to determine the positions + bool init( + const llama_batch & batch_inp, + const llama_vocab & vocab, + const llama_memory_i * memory, + bool embd_all); + + const llama_batch & get_batch() const; + + uint32_t get_n_outputs() const; + + llama_pos seq_pos_min(llama_seq_id seq_id) const; + llama_pos seq_pos_max(llama_seq_id seq_id) const; + +private: + void clear(); + + llama_batch batch; + + uint32_t n_outputs; std::array seq_id_0 = { 0 }; // default sequence id + std::vector pos; std::vector n_seq_id; std::vector seq_id; - std::vector logits; + std::vector output; - // optionally fulfill the batch returned by llama_batch_get_one - llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); + std::vector> seq_pos; // seq_pos[s]: the set of positions in sequence s + std::vector> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 + + int debug; }; diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 721faa4e8..0839cad3e 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -35,6 +35,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, @@ -50,8 +51,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, - { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, - { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, + { "chatglm4", LLM_CHAT_TEMPLATE_CHATGLM_4 }, { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, @@ -62,6 +63,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, { "bailing", LLM_CHAT_TEMPLATE_BAILING }, { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, + { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -81,7 +83,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { if (tmpl_contains("<|im_start|>")) { return tmpl_contains("<|im_sep|>") ? LLM_CHAT_TEMPLATE_PHI_4 - : LLM_CHAT_TEMPLATE_CHATML; + : tmpl_contains("") + ? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml + : LLM_CHAT_TEMPLATE_CHATML; } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { if (tmpl_contains("[SYSTEM_PROMPT]")) { return LLM_CHAT_TEMPLATE_MISTRAL_V7; @@ -119,8 +123,12 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { return LLM_CHAT_TEMPLATE_PHI_3; + } else if (tmpl_contains("[gMASK]")) { + return LLM_CHAT_TEMPLATE_CHATGLM_4; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; + } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) { + return LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { return LLM_CHAT_TEMPLATE_ZEPHYR; } else if (tmpl_contains("bos_token + message['role']")) { @@ -149,9 +157,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_LLAMA_3; } else if (tmpl_contains("[gMASK]sop")) { // chatglm3-6b - return LLM_CHAT_TEMPLATE_CHATGML_3; - } else if (tmpl_contains("[gMASK]")) { - return LLM_CHAT_TEMPLATE_CHATGML_4; + return LLM_CHAT_TEMPLATE_CHATGLM_3; } else if (tmpl_contains(LU8("<用户>"))) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF return LLM_CHAT_TEMPLATE_MINICPM; @@ -177,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_BAILING; } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { return LLM_CHAT_TEMPLATE_LLAMA4; + } else if (tmpl_contains("<|endofuserprompt|>")) { + return LLM_CHAT_TEMPLATE_DOTS1; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -197,19 +205,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|im_start|>assistant\n"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) { + } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) { // Official mistral 'v7' template // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 + // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken + const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : ""; for (auto message : chat) { std::string role(message->role); std::string content(message->content); if (role == "system") { - ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]"; + ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]"; } else if (role == "user") { - ss << "[INST] " << content << "[/INST]"; - } - else { - ss << " " << content << ""; + ss << "[INST]" << trailing_space << content << "[/INST]"; + } else { + ss << trailing_space << content << ""; } } } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 @@ -324,7 +333,7 @@ int32_t llm_chat_apply_template( std::string role(message->role); if (role == "system") { // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken - system_prompt = trim(message->content); + system_prompt += trim(message->content); continue; } // in gemma, "assistant" is "model" @@ -346,7 +355,7 @@ int32_t llm_chat_apply_template( std::string role(message->role); if (role == "system") { // there is no system message support, we will merge it with user prompt - system_prompt = message->content; + system_prompt += message->content; continue; } else if (role == "user") { ss << "Human: "; @@ -432,7 +441,7 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_3) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -442,14 +451,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); ss << "<|" << role << "|>" << "\n" << message->content; } if (add_ass) { - ss << "<|assistant|>"; + ss << "<|assistant|>\n"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { for (auto message : chat) { @@ -620,7 +629,38 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|header_start|>assistant<|header_end|>\n\n"; } - } else { + } else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) { + // SmolVLM + ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << message->content << "\n\n"; + } else if (role == "user") { + ss << "User: " << message->content << "\n"; + } else { + ss << "Assistant: " << message->content << "\n"; + } + } + if (add_ass) { + ss << "Assistant:"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) { + // dots.llm1.inst (DOTS1) + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>" << message->content << "<|endofsystem|>"; + } else if (role == "user") { + ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>"; + } else { + ss << "<|response|>" << message->content << "<|endofresponse|>"; + } + } + if (add_ass) { + ss << "<|response|>"; + } + } else { // template not supported return -1; } diff --git a/src/llama-chat.h b/src/llama-chat.h index 34537ca21..38800010a 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -14,6 +14,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MISTRAL_V3, LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN, LLM_CHAT_TEMPLATE_PHI_3, LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_FALCON_3, @@ -29,8 +30,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_DEEPSEEK_3, LLM_CHAT_TEMPLATE_COMMAND_R, LLM_CHAT_TEMPLATE_LLAMA_3, - LLM_CHAT_TEMPLATE_CHATGML_3, - LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_CHATGLM_3, + LLM_CHAT_TEMPLATE_CHATGLM_4, LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, @@ -41,6 +42,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_YANDEX, LLM_CHAT_TEMPLATE_BAILING, LLM_CHAT_TEMPLATE_LLAMA4, + LLM_CHAT_TEMPLATE_SMOLVLM, + LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4735e98ea..f56a58e9b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1,15 +1,16 @@ #include "llama-context.h" #include "llama-impl.h" +#include "llama-batch.h" #include "llama-io.h" +#include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" -#include "llama-kv-cache.h" -#include -#include -#include #include +#include +#include +#include // // llama_context @@ -18,7 +19,8 @@ llama_context::llama_context( const llama_model & model, llama_context_params params) : - model(model) { + model(model), + batch_allocr(std::make_unique()) { LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__); t_start_us = model.t_start_us; @@ -26,7 +28,11 @@ llama_context::llama_context( const auto & hparams = model.hparams; - cparams.n_seq_max = std::max(1u, params.n_seq_max); + cparams.n_seq_max = std::max(1u, params.n_seq_max); + if (cparams.n_seq_max > LLAMA_MAX_SEQ) { + throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -95,6 +101,8 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.op_offload = params.op_offload; + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; LLAMA_LOG_INFO("%s: n_seq_max = %u\n", __func__, cparams.n_seq_max); @@ -113,11 +121,14 @@ llama_context::llama_context( } if (n_ctx_per_seq > hparams.n_ctx_train) { - LLAMA_LOG_WARN("%s: n_ctx_pre_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", + LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n", __func__, n_ctx_per_seq, hparams.n_ctx_train); } - logits_all = params.logits_all; + if (!params.swa_full && cparams.n_seq_max > 1 && hparams.is_swa_any()) { + LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n", + __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573"); + } if (!hparams.vocab_only) { // GPU backends @@ -176,44 +187,14 @@ llama_context::llama_context( } // init the memory module - // TODO: for now, always create a unified KV cache if (!hparams.vocab_only) { - kv_self.reset(static_cast(model.create_memory())); + llama_memory_params params_mem = { + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + }; - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams)); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; - - if (llama_model_is_recurrent(&model)) { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - } - - GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); - GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); - - if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { - throw std::runtime_error("failed to initialize self-attention cache"); - } - - { - const size_t memory_size_k = kv_self->size_k_bytes(); - const size_t memory_size_v = kv_self->size_v_bytes(); - - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } + memory.reset(model.create_memory(params_mem, cparams)); } // init backends @@ -277,7 +258,7 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel)); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); if (pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); @@ -285,16 +266,10 @@ llama_context::llama_context( } // reserve worst-case graph - if (!hparams.vocab_only) { - const uint32_t n_seqs = 1; // TODO: worst-case number of sequences + if (!hparams.vocab_only && memory) { + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - - // restore later - // TODO: something cleaner - const auto n_outputs_save = n_outputs; - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); int n_splits_pp = -1; @@ -304,23 +279,18 @@ llama_context::llama_context( int n_nodes_tg = -1; // simulate full KV cache - kv_self->n = kv_self->size; + + const auto mstate = memory->init_full(); + if (!mstate) { + throw std::runtime_error("failed to initialize KV cache"); + } cross.v_embd.clear(); // reserve pp graph first so that buffers are only allocated once { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - // max number of outputs - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -330,16 +300,8 @@ llama_context::llama_context( // reserve with tg graph to get the number of splits and nodes { - llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_tg.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(1, 1, 1, mstate.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -349,22 +311,12 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - n_outputs = ubatch_pp.n_tokens; - - LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs); - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); - - if (!ggml_backend_sched_reserve(sched.get(), gf)) { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } } - n_outputs = n_outputs_save; - for (size_t i = 0; i < backend_ptrs.size(); ++i) { ggml_backend_t backend = backend_ptrs[i]; ggml_backend_buffer_type_t buft = backend_buft[i]; @@ -390,7 +342,9 @@ llama_context::llama_context( } } -llama_context::~llama_context() = default; +llama_context::~llama_context() { + ggml_opt_free(opt_ctx); +} void llama_context::synchronize() { ggml_backend_sched_synchronize(sched.get()); @@ -426,6 +380,18 @@ const llama_model & llama_context::get_model() const { return model; } +const llama_cparams & llama_context::get_cparams() const { + return cparams; +} + +ggml_backend_sched_t llama_context::get_sched() const { + return sched.get(); +} + +ggml_context * llama_context::get_ctx_compute() const { + return ctx_compute.get(); +} + uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } @@ -454,370 +420,71 @@ uint32_t llama_context::n_threads_batch() const { return cparams.n_threads_batch; } -llama_kv_cache * llama_context::get_kv_self() { - return kv_self.get(); +llama_memory_t llama_context::get_memory() const { + return memory.get(); } -const llama_kv_cache * llama_context::get_kv_self() const { - return kv_self.get(); +// deprecated +void llama_context::kv_self_defrag_sched() { + if (!memory) { + return; + } + + memory_force_optimize = true; } -ggml_tensor * llama_context::build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale, - ggml_backend_buffer * bbuf) const { - const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; +// deprecated +bool llama_context::kv_self_update(bool optimize) { + if (!memory) { + return false; + } - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_attn_factor = cparams.yarn_attn_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; + { + // TODO: remove in the future + optimize |= memory_force_optimize; + memory_force_optimize = false; - const auto & hparams = model.hparams; - - const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type; - - ggml_tensor * tmp; - - if (ggml_is_quantized(cur->type)) { - // dequantize to f32 -> RoPE -> quantize back - tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32); - - if (bbuf) { - for (const auto & backend : backends) { - // Figure out which backend KV cache belongs to - if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) { - ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get()); - break; + const auto mstate = memory->init_update(this, optimize); + switch (mstate->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + // noop + } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + // no updates need to be performed + return false; } - } - } - - tmp = ggml_rope_ext_inplace(ctx0, tmp, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - - tmp = ggml_cpy(ctx0, tmp, cur); - } else { - // we rotate only the first n_rot dimensions - tmp = ggml_rope_ext_inplace(ctx0, cur, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - } - - return tmp; -} - -class llm_graph_input_k_shift : public llm_graph_input_i { -public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_k_shift() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * k_shift; // I32 [kv_size] - - const llama_kv_cache_unified * kv_self; -}; - -void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - if (k_shift) { - assert(ggml_backend_buffer_is_host(k_shift->buffer)); - - int32_t * data = (int32_t *) k_shift->data; - - for (uint32_t i = 0; i < kv_self->size; ++i) { - data[i] = kv_self->cells[i].delta; - } - } -} - -llm_graph_result_ptr llama_context::build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & hparams = model.hparams; - - const auto & n_layer = hparams.n_layer; - - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - //GGML_ASSERT(kv_self->size == n_ctx); - - auto inp = std::make_unique(kv_self.get()); - - inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx); - ggml_set_input(inp->k_shift); - - for (uint32_t il = 0; il < n_layer; ++il) { - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - - const bool is_swa = hparams.is_swa(il); - - // note: the swa rope params could become part of the cparams in the future - // if we decide to make them configurable, like the non-sliding ones - const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; - const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; - - ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il); - - ggml_tensor * k = - ggml_view_3d(ctx0, kv_self->k_l[il], - n_embd_head_k, n_head_kv, kv_self->size, - ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - 0); - - ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer); - - ggml_build_forward_expand(gf, cur); - } - - res->add_input(std::move(inp)); - - return res; -} - -llm_graph_result_ptr llama_context::build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & hparams = model.hparams; - - const auto & ids = kv_self->defrag_info.ids; - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + LLAMA_LOG_ERROR("%s: failed to prepare memory update\n", __func__); + return false; } - } - - i += nm - 1; } - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; -} - -void llama_context::kv_self_update() { - auto & kv = kv_self; - - bool need_reserve = false; - - if (kv->has_shift) { - if (!kv->get_can_shift()) { - GGML_ABORT("The current context does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_shift(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); - - need_reserve = true; - } - - { - kv->has_shift = false; - - for (uint32_t i = 0; i < kv->size; ++i) { - kv->cells[i].delta = 0; - } + if (!mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); } } - // defragment the KV cache if needed - if (kv->do_defrag) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - if (kv->defrag_prepare(graph_max_nodes())) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_defrag(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); - - need_reserve = true; + // if the memory module did any computation, we have to reserve a new worst-case graph + { + const auto mstate = memory->init_full(); + if (!mstate) { + throw std::runtime_error("failed to initialize memory state"); } - kv->do_defrag = false; - } + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - // reserve a worst case graph if needed - if (need_reserve) { - LLAMA_LOG_DEBUG("%s: reserving a worst case graph\n", __func__); - - // build worst-case graph - uint32_t n_seqs = 1; // TODO: worst-case number of sequences - uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - // simulate full KV cache - kv_self->n = kv_self->size; - - llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; - - auto * gf = graph_init(); - graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT); - - // initialize scheduler with the worst-case graph - ggml_backend_sched_reset(sched.get()); - if (!ggml_backend_sched_reserve(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mstate.get()); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } } + + return true; } enum llama_pooling_type llama_context::pooling_type() const { @@ -825,14 +492,11 @@ enum llama_pooling_type llama_context::pooling_type() const { } float * llama_context::get_logits() { - // reorder logits for backward compatibility - output_reorder(); - return logits; } float * llama_context::get_logits_ith(int32_t i) { - int32_t j = -1; + int64_t j = -1; try { if (logits == nullptr) { @@ -855,7 +519,7 @@ float * llama_context::get_logits_ith(int32_t i) { } if (j >= n_outputs) { // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } return logits + j*model.vocab.n_tokens(); @@ -870,14 +534,11 @@ float * llama_context::get_logits_ith(int32_t i) { } float * llama_context::get_embeddings() { - // reorder embeddings for backward compatibility - output_reorder(); - return embd; } float * llama_context::get_embeddings_ith(int32_t i) { - int32_t j = -1; + int64_t j = -1; try { if (embd == nullptr) { @@ -900,7 +561,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { } if (j >= n_outputs) { // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, n_outputs)); + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } return embd + j*model.hparams.n_embd; @@ -1017,44 +678,84 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -int llama_context::encode(llama_batch & inp_batch) { - if (inp_batch.n_tokens == 0) { +llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_state_i * mstate, ggml_status & ret) { + if (mstate && !mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to apply memory state\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto * gf = graph_init(); + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mstate); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } + + res->set_inputs(&ubatch); + + const auto status = graph_compute(gf, ubatch.n_tokens > 1); + if (status != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status); + ret = status; + return nullptr; + } + + ret = GGML_STATUS_SUCCESS; + + return res; +} + +int llama_context::encode(const llama_batch & batch_inp) { + if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // note: during encode, we always pass the full sequence starting from pos = 0 + if (!batch_allocr->init(batch_inp, model.vocab, nullptr, true)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } - const llama_batch & batch = batch_allocr.batch; - const int32_t n_tokens = batch.n_tokens; + const llama_batch & batch = batch_allocr->get_batch(); - const auto & hparams = model.hparams; + const uint32_t n_tokens = batch.n_tokens; GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT - if (batch.token) { - for (int32_t i = 0; i < n_tokens; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { - LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); - return -1; - } - } - } - // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot - GGML_ASSERT(cparams.n_ubatch >= (uint32_t) n_tokens && "encoder requires n_ubatch >= n_tokens"); + GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens"); if (t_compute_start_us == 0) { t_compute_start_us = ggml_time_us(); } + // TODO: this clear of the buffer can easily be forgotten - need something better + embd_seq.clear(); + n_queued_tokens += n_tokens; + const auto & hparams = model.hparams; + const int64_t n_embd = hparams.n_embd; - sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true); const llama_ubatch ubatch = sbatch.split_simple(n_tokens); @@ -1064,14 +765,12 @@ int llama_context::encode(llama_batch & inp_batch) { return -2; }; - for (int32_t i = 0; i < n_tokens; ++i) { + for (uint32_t i = 0; i < n_tokens; ++i) { output_ids[i] = i; } n_outputs = n_tokens; - //batch_manager->prepare(ubatch); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -1082,26 +781,18 @@ int llama_context::encode(llama_batch & inp_batch) { // ref: https://github.com/ggml-org/llama.cpp/pull/12181#issuecomment-2730451223 cparams.causal_attn = false; - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_ENCODER); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(&ubatch); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); cparams.causal_attn = causal_attn_org; - const auto compute_status = graph_compute(gf, n_tokens > 1); - switch (compute_status) { - case GGML_STATUS_SUCCESS: - break; - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + if (!res) { + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); + } } auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); @@ -1111,12 +802,12 @@ int llama_context::encode(llama_batch & inp_batch) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); - GGML_ASSERT(embd != nullptr); - switch (cparams.pooling_type) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings + GGML_ASSERT(embd != nullptr); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); } break; @@ -1130,7 +821,8 @@ int llama_context::encode(llama_batch & inp_batch) { GGML_ASSERT(!ubatch.equal_seqs); // TODO: handle equal splits - for (int32_t i = 0; i < n_tokens; i++) { + // TODO: fix indexing [UBATCH_IDX] + for (uint32_t i = 0; i < n_tokens; i++) { const llama_seq_id seq_id = ubatch.seq_id[i][0]; if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { continue; @@ -1141,11 +833,20 @@ int llama_context::encode(llama_batch & inp_batch) { } break; case LLAMA_POOLING_TYPE_RANK: { - // TODO: this likely should be the same logic as in llama_decoder_internal, but better to - // wait for an encoder model that requires this pooling type in order to test it - // https://github.com/ggerganov/llama.cpp/pull/9510 - GGML_ABORT("RANK pooling not implemented yet"); - } + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + const uint32_t n_cls_out = hparams.n_cls_out; + + // TODO: fix indexing [UBATCH_IDX] + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_id)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); @@ -1170,10 +871,10 @@ int llama_context::encode(llama_batch & inp_batch) { // remember the sequence ids used during the encoding - needed for cross attention later cross.seq_ids_enc.resize(n_tokens); - for (int32_t i = 0; i < n_tokens; i++) { + for (uint32_t i = 0; i < n_tokens; i++) { cross.seq_ids_enc[i].clear(); - for (int s = 0; s < ubatch.n_seq_id[i]; s++) { - llama_seq_id seq_id = ubatch.seq_id[i][s]; + for (int s = 0; s < batch.n_seq_id[i]; s++) { + llama_seq_id seq_id = batch.seq_id[i][s]; cross.seq_ids_enc[i].insert(seq_id); } } @@ -1182,36 +883,45 @@ int llama_context::encode(llama_batch & inp_batch) { return 0; } -int llama_context::decode(llama_batch & inp_batch) { - if (inp_batch.n_tokens == 0) { +int llama_context::decode(const llama_batch & batch_inp) { + if (!memory) { + LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); + return encode(batch_inp); + } + + if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // when computing embeddings, all tokens are output + const bool embd_all = cparams.embeddings; - const llama_batch & batch = batch_allocr.batch; + if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) { + LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); + return -1; + } + + const llama_batch & batch = batch_allocr->get_batch(); const auto & vocab = model.vocab; const auto & hparams = model.hparams; const int32_t n_vocab = vocab.n_tokens(); + const int64_t n_embd = hparams.n_embd; - const int64_t n_tokens_all = batch.n_tokens; - const int64_t n_embd = hparams.n_embd; - - llama_kv_cache_guard kv_guard(kv_self.get()); + const uint32_t n_tokens_all = batch.n_tokens; GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT - if (batch.token) { - for (int64_t i = 0; i < n_tokens_all; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { - LLAMA_LOG_ERROR("%s: invalid token[%" PRId64 "] = %d\n", __func__, i, batch.token[i]); - throw std::runtime_error("invalid token"); - } + const uint32_t n_outputs_all = batch_allocr->get_n_outputs(); + + if (embd_all) { + // require that all tokens are output + if (n_outputs_all != n_tokens_all) { + LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n", + __func__, n_outputs_all, n_tokens_all); + return -1; } } @@ -1224,61 +934,71 @@ int llama_context::decode(llama_batch & inp_batch) { } n_queued_tokens += n_tokens_all; - // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens - const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; - + // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); - int64_t n_outputs_all = 0; + bool did_optimize = false; - // count outputs - if (batch.logits && !embd_pooled) { - for (uint32_t i = 0; i < n_tokens_all; ++i) { - n_outputs_all += batch.logits[i] != 0; + // handle any pending defrags/shifts + kv_self_update(false); + + llama_memory_state_ptr mstate; + + while (true) { + mstate = memory->init_batch(batch, cparams.n_ubatch, embd_all); + if (!mstate) { + return -2; } - } else if (logits_all || embd_pooled) { - n_outputs_all = n_tokens_all; - } else { - // keep last output only - n_outputs_all = 1; + + switch (mstate->get_status()) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + } break; + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + LLAMA_LOG_ERROR("%s: unexpected memory state status: %d\n", __func__, mstate->get_status()); + + return -2; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + { + if (!did_optimize) { + did_optimize = true; + + if (kv_self_update(true)) { + LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, batch.n_tokens); + + continue; + } + } + + LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, batch.n_tokens); + + return 1; + } + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, batch.n_tokens); + + return -2; + } + } + + break; } - const bool logits_all = n_outputs_all == n_tokens_all; - - sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self->recurrent, - /* logits_all */ logits_all); - // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { - LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all); + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); return -2; }; - // handle any pending defrags/shifts - kv_self_update(); - int64_t n_outputs_prev = 0; - while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = llama_ubatch(); + do { + const auto & ubatch = mstate->get_ubatch(); - const auto & n_ubatch = cparams.n_ubatch; - - if (kv_self->recurrent) { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - ubatch = sbatch.split_seq(cparams.n_ubatch); - } else { - // recurrent model architectures are easier to implement - // with equal-length sequences - ubatch = sbatch.split_equal(cparams.n_ubatch); - } - } else { - ubatch = sbatch.split_simple(n_ubatch); - } - - // count the outputs in this u_batch + // count the outputs in this ubatch { int32_t n_outputs_new = 0; @@ -1295,47 +1015,41 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs = n_outputs_new; } - // find KV slot - { - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - - return 1; - } - - if (!kv_self->recurrent) { - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = kv_self->get_padding(cparams); - kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad))); - } - } - - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DECODER); + ggml_status status; + const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mstate.get(), status); - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + if (!res) { + // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache + llama_pos pos_min[LLAMA_MAX_SEQ]; + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + pos_min[s] = std::numeric_limits::max(); + } - ggml_backend_sched_alloc_graph(sched.get(), gf); + // TODO: fix sequence indexing + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + const auto & seq_id = ubatch.seq_id[i][0]; - res->set_inputs(&ubatch); + pos_min[seq_id] = std::min(pos_min[seq_id], ubatch.pos[i]); + } - const auto compute_status = graph_compute(gf, ubatch.n_tokens > 1); - if (compute_status != GGML_STATUS_SUCCESS) { - switch (compute_status) { - case GGML_STATUS_ABORTED: - return 2; - case GGML_STATUS_ALLOC_FAILED: - return -2; - case GGML_STATUS_FAILED: - default: - return -3; + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (pos_min[s] == std::numeric_limits::max()) { + continue; + } + + LLAMA_LOG_WARN("%s: removing KV cache entries for seq_id = %d, pos = [%d, +inf)\n", __func__, s, pos_min[s]); + + memory->seq_rm(s, pos_min[s], -1); + } + + switch (status) { + case GGML_STATUS_ABORTED: return 2; + case GGML_STATUS_ALLOC_FAILED: return -2; + case GGML_STATUS_FAILED: return -3; + case GGML_STATUS_SUCCESS: GGML_ABORT("should not happen"); } } @@ -1344,7 +1058,7 @@ int llama_context::decode(llama_batch & inp_batch) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = cparams.embeddings ? nullptr : res->get_logits(); + auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; if (t_embd && res->get_embd_pooled()) { @@ -1422,50 +1136,71 @@ int llama_context::decode(llama_batch & inp_batch) { } n_outputs_prev += n_outputs; - } + } while (mstate->next()); - // finalize the batch processing - kv_guard.commit(); + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; // set output mappings - { + if (n_outputs > 0) { bool sorted_output = true; - GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all); + auto & out_ids = mstate->out_ids(); - for (int64_t i = 0; i < n_outputs_all; ++i) { - int64_t out_id = sbatch.out_ids[i]; + GGML_ASSERT(out_ids.size() == (size_t) n_outputs); + + for (int64_t i = 0; i < n_outputs; ++i) { + int64_t out_id = out_ids[i]; output_ids[out_id] = i; if (out_id != i) { sorted_output = false; } } - if (sorted_output) { - sbatch.out_ids.clear(); + // make the outputs have the same order they had in the user-provided batch + // note: this is mostly relevant for recurrent models atm + if (!sorted_output) { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint64_t n_embd = model.hparams.n_embd; + + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (uint32_t i = 0; i < n_outputs - 1; ++i) { + uint32_t j_min = i; + for (uint32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { + continue; + } + std::swap(out_ids[i], out_ids[j_min]); + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); + } + } + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); + } + } + } + + std::fill(output_ids.begin(), output_ids.end(), -1); + + for (uint32_t i = 0; i < n_outputs; ++i) { + output_ids[out_ids[i]] = i; + } } } - // set to total number of outputs in the batch, for use in llama_get_logits_ith - n_outputs = n_outputs_all; - // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - // decide if we need to defrag the kv cache - if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > cparams.defrag_thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - kv_self->defrag(); - } - } - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to // overlap with device computation. ggml_backend_sched_reset(sched.get()); @@ -1477,7 +1212,7 @@ int llama_context::decode(llama_batch & inp_batch) { // output // -int32_t llama_context::output_reserve(int32_t n_outputs) { +uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1487,9 +1222,8 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_vocab = vocab.n_tokens(); const auto n_embd = hparams.n_embd; - // TODO: use a per-batch flag for logits presence instead - bool has_logits = !cparams.embeddings; - bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); + bool has_logits = true; + bool has_embd = cparams.embeddings; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1543,52 +1277,11 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); - ggml_backend_buffer_clear(buf_output.get(), 0); - - this->n_outputs = 0; - this->n_outputs_max = n_outputs_max; + this->n_outputs = 0; return n_outputs_max; } -void llama_context::output_reorder() { - auto & out_ids = sbatch.out_ids; - if (!out_ids.empty()) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint32_t n_embd = model.hparams.n_embd; - - GGML_ASSERT((size_t) n_outputs == out_ids.size()); - - // TODO: is there something more efficient which also minimizes swaps? - // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) - for (int32_t i = 0; i < n_outputs - 1; ++i) { - int32_t j_min = i; - for (int32_t j = i + 1; j < n_outputs; ++j) { - if (out_ids[j] < out_ids[j_min]) { - j_min = j; - } - } - if (j_min == i) { continue; } - std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } - } - std::fill(output_ids.begin(), output_ids.end(), -1); - for (int32_t i = 0; i < n_outputs; ++i) { - output_ids[out_ids[i]] = i; - } - out_ids.clear(); - } -} - // // graph // @@ -1609,11 +1302,52 @@ ggml_cgraph * llama_context::graph_init() { return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); } +ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate) { + LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs); + + if (n_tokens % n_seqs != 0) { + n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs + n_outputs = std::min(n_outputs, n_tokens); + + LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); + } + + // store the n_outputs as it is, and restore it afterwards + // TODO: not sure if needed, might simplify in the future by removing this + const auto save_n_outputs = this->n_outputs; + + this->n_outputs = n_outputs; + + llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate); + + this->n_outputs = save_n_outputs; + + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); + return nullptr; + } + + ggml_backend_sched_reset(sched.get()); + + // initialize scheduler with the specified graph + if (!ggml_backend_sched_reserve(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__); + return nullptr; + } + + return gf; +} + llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype) { + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate) { return model.build_graph( { /*.ctx =*/ ctx, @@ -1625,7 +1359,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ kv_self.get(), + /*.mstate =*/ mstate, /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2029,21 +1763,17 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - output_reorder(); - const auto n_outputs = this->n_outputs; const auto & output_ids = this->output_ids; std::vector w_output_pos; - GGML_ASSERT(n_outputs <= n_outputs_max); - w_output_pos.resize(n_outputs); // build a more compact representation of the output ids for (size_t i = 0; i < n_batch(); ++i) { // map an output id to a position in the batch - int32_t pos = output_ids[i]; + int64_t pos = output_ids[i]; if (pos >= 0) { GGML_ASSERT(pos < n_outputs); w_output_pos[pos] = i; @@ -2083,8 +1813,10 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } - LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); - kv_self->state_write(io); + if (memory != nullptr) { + LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + memory->state_write(io); + } return io.n_bytes(); } @@ -2167,8 +1899,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } } - LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); - kv_self->state_read(io); + if (memory) { + LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + + memory->state_read(io); + } return io.n_bytes(); } @@ -2176,7 +1911,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); - kv_self->state_write(io, seq_id); + if (memory) { + memory->state_write(io, seq_id); + } return io.n_bytes(); } @@ -2184,7 +1921,9 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); - kv_self->state_read(io, seq_id); + if (memory) { + memory->state_read(io, seq_id); + } return io.n_bytes(); } @@ -2212,6 +1951,213 @@ void llama_context::perf_reset() { t_p_eval_us = n_p_eval = 0; } +// +// training +// + +static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) { + if (!tensor || tensor->type != GGML_TYPE_F32) { + return; + } + if (!param_filter(tensor, userdata)) { + return; + } + if (strcmp(tensor->name, "token_embd.weight") == 0) { + return; // FIXME + } + if (strcmp(tensor->name, "rope_freqs.weight") == 0) { + return; // FIXME + } + ggml_set_param(tensor); +} + +void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) { + GGML_ASSERT(!opt_ctx); + model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx(); + const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0); + GGML_ASSERT(n_batch % n_ubatch == 0); + + ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY); + opt_params.opt_period = n_batch / n_ubatch; + opt_params.get_opt_pars = lopt_params.get_opt_pars; + opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud; + + opt_ctx = ggml_opt_init(opt_params); + + llama_opt_param_filter param_filter = lopt_params.param_filter; + void * param_filter_ud = lopt_params.param_filter_ud; + + //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME + llama_set_param(model->type_embd, param_filter, param_filter_ud); + llama_set_param(model->pos_embd, param_filter, param_filter_ud); + llama_set_param(model->tok_norm, param_filter, param_filter_ud); + llama_set_param(model->tok_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm, param_filter, param_filter_ud); + llama_set_param(model->output_norm_b, param_filter, param_filter_ud); + llama_set_param(model->output, param_filter, param_filter_ud); + llama_set_param(model->output_b, param_filter, param_filter_ud); + llama_set_param(model->output_norm_enc, param_filter, param_filter_ud); + llama_set_param(model->cls, param_filter, param_filter_ud); + llama_set_param(model->cls_b, param_filter, param_filter_ud); + llama_set_param(model->cls_out, param_filter, param_filter_ud); + llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + + for (struct llama_layer & layer : model->layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + llama_set_param(reinterpret_cast(&layer)[i], param_filter, param_filter_ud); + } + } +} + +void llama_context::opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start) { + GGML_ASSERT(opt_ctx); + const uint32_t n_ctx = llama_model_n_ctx_train(&model); + const uint32_t n_batch = std::min(this->n_batch(), n_ctx); + const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch); + + memory->clear(true); + + for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) { + batch.n_tokens = n_batch; + for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) { + batch.token [pos_batch] = tokens[pos_ctx + pos_batch]; + batch.pos [pos_batch] = pos_ctx + pos_batch; + batch.n_seq_id[pos_batch] = 1; + batch.seq_id [pos_batch][0] = 0; + batch.logits [pos_batch] = true; + } + + const auto n_tokens_all = batch.n_tokens; + + n_queued_tokens += n_tokens_all; + + embd_seq.clear(); + + uint32_t n_outputs_all = n_tokens_all; + + auto mstate = memory->init_batch(batch, cparams.n_ubatch, true); + if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__); + break; + } + + // reserve output buffer + if (output_reserve(n_outputs_all) < n_outputs_all) { + LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); + GGML_ABORT("TODO: handle this error"); + }; + + uint32_t pos_batch = 0; + do { + const auto & ubatch = mstate->get_ubatch(); + + n_outputs = ubatch.n_tokens; + + if (!mstate->apply()) { + LLAMA_LOG_ERROR("%s: failed to update the memory state\n", __func__); + break; + } + + auto * gf = graph_init(); + auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate.get()); + + struct ggml_context * ctx_compute_opt; + { + const size_t size_gf = ggml_graph_size(gf); + const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute_opt = ggml_init(params); + } + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_alloc(opt_ctx, train); + + res->set_inputs(&ubatch); + { + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + GGML_ASSERT(labels->ne[1] == n_ubatch); + ggml_set_zero(labels); + const float onef = 1.0f; + for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) { + const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch; + GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]); + ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float)); + } + } + ggml_opt_eval(opt_ctx, result); + if (callback) { + callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start); + } + ggml_free(ctx_compute_opt); + + pos_batch += ubatch.n_tokens; + } while (mstate->next()); + } +} + +void llama_context::opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + const uint32_t n_ctx = this->n_ctx(); + const uint32_t n_batch = std::min(cparams.n_batch, n_ctx); + const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch); + const int64_t ndata = ggml_opt_dataset_ndata(dataset); + + GGML_ASSERT(idata_split >= 0); + GGML_ASSERT(idata_split <= ndata); + + const uint32_t ubatch_per_ctx = n_ctx / n_ubatch; + + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + std::vector tokens(n_ctx); + std::vector labels_sparse(n_ctx); + + int64_t idata = 0; + + int64_t t_loop_start = ggml_time_us(); + int64_t ndata_in_loop = idata_split*ubatch_per_ctx; + for (; idata < idata_split; ++idata) { + constexpr bool train = true; + const int64_t idata_in_loop = idata*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch, + callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + t_loop_start = ggml_time_us(); + ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx; + for (; idata < ndata; ++idata) { + constexpr bool train = false; + const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx; + + ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata); + opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch, + callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start); + } + + llama_batch_free(batch); +} + // // interface implementation // @@ -2239,13 +2185,14 @@ llama_context_params llama_context_default_params() { /*.cb_eval_user_data =*/ nullptr, /*.type_k =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16, - /*.logits_all =*/ false, + /*.abort_callback =*/ nullptr, + /*.abort_callback_data =*/ nullptr, /*.embeddings =*/ false, /*.offload_kqv =*/ true, /*.flash_attn =*/ false, /*.no_perf =*/ true, - /*.abort_callback =*/ nullptr, - /*.abort_callback_data =*/ nullptr, + /*.op_offload =*/ true, + /*.swa_full =*/ true, }; return result; @@ -2320,12 +2267,14 @@ const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } +// deprecated llama_kv_cache * llama_get_kv_self(llama_context * ctx) { - return ctx->get_kv_self(); + return dynamic_cast(ctx->get_memory()); } +// deprecated void llama_kv_self_update(llama_context * ctx) { - ctx->kv_self_update(); + ctx->kv_self_update(false); } enum llama_pooling_type llama_pooling_type(const llama_context * ctx) { @@ -2441,27 +2390,108 @@ int32_t llama_apply_adapter_cvec( } // -// kv cache view +// memory // -llama_kv_cache_view llama_kv_cache_view_init(const llama_context * ctx, int32_t n_seq_max) { - const auto * kv = ctx->get_kv_self(); - if (kv == nullptr) { - LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); - return {}; - } - - return llama_kv_cache_view_init(*kv, n_seq_max); +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + return ctx->get_memory(); } -void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * view) { - const auto * kv = ctx->get_kv_self(); - if (kv == nullptr) { - LLAMA_LOG_WARN("%s: the context does not have a KV cache\n", __func__); +void llama_memory_clear(llama_memory_t mem, bool data) { + if (!mem) { return; } - llama_kv_cache_view_update(view, kv); + mem->clear(data); +} + +bool llama_memory_seq_rm( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + if (!mem) { + return true; + } + + return mem->seq_rm(seq_id, p0, p1); +} + +void llama_memory_seq_cp( + llama_memory_t mem, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (!mem) { + return; + } + + mem->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_seq_keep( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return; + } + + mem->seq_keep(seq_id); +} + +void llama_memory_seq_add( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + if (!mem) { + return; + } + + mem->seq_add(seq_id, p0, p1, delta); +} + +void llama_memory_seq_div( + llama_memory_t mem, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (!mem) { + return; + } + + mem->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_seq_pos_min( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return -1; + } + + return mem->seq_pos_min(seq_id); +} + +llama_pos llama_memory_seq_pos_max( + llama_memory_t mem, + llama_seq_id seq_id) { + if (!mem) { + return -1; + } + + return mem->seq_pos_max(seq_id); +} + +bool llama_memory_can_shift(llama_memory_t mem) { + if (!mem) { + return false; + } + + return mem->get_can_shift(); } // @@ -2469,202 +2499,161 @@ void llama_kv_cache_view_update(const llama_context * ctx, llama_kv_cache_view * // // deprecated -int32_t llama_get_kv_cache_token_count(const llama_context * ctx) { - return llama_kv_self_n_tokens(ctx); -} - int32_t llama_kv_self_n_tokens(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + const auto * kv = llama_get_memory(ctx); if (!kv) { return 0; } - return kv->get_n_tokens(); + int32_t res = 0; + + for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { + const llama_pos p0 = kv->seq_pos_min(s); + const llama_pos p1 = kv->seq_pos_max(s); + + if (p0 >= 0) { + res += (p1 - p0) + 1; + } + } + + return res; } // deprecated -int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) { - return llama_kv_self_used_cells(ctx); -} - +// note: this is the same as above - will be removed anyway, so it's ok int32_t llama_kv_self_used_cells(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + const auto * kv = llama_get_memory(ctx); if (!kv) { return 0; } - return kv->get_used_cells(); + int32_t res = 0; + + for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) { + const llama_pos p0 = kv->seq_pos_min(s); + const llama_pos p1 = kv->seq_pos_max(s); + + if (p0 >= 0) { + res += (p1 - p0) + 1; + } + } + + return res; } // deprecated -void llama_kv_cache_clear(llama_context * ctx) { - llama_kv_self_clear(ctx); -} - void llama_kv_self_clear(llama_context * ctx) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - kv->clear(); + llama_memory_clear(kv, true); } // deprecated -bool llama_kv_cache_seq_rm( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - return llama_kv_self_seq_rm(ctx, seq_id, p0, p1); -} - bool llama_kv_self_seq_rm( llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return true; } - return kv->seq_rm(seq_id, p0, p1); + return llama_memory_seq_rm(kv, seq_id, p0, p1); } // deprecated -void llama_kv_cache_seq_cp( - llama_context * ctx, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); -} - void llama_kv_self_seq_cp( llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); + llama_memory_seq_cp(kv, seq_id_src, seq_id_dst, p0, p1); } // deprecated -void llama_kv_cache_seq_keep( - llama_context * ctx, - llama_seq_id seq_id) { - return llama_kv_self_seq_keep(ctx, seq_id); -} - void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - return kv->seq_keep(seq_id); + llama_memory_seq_keep(kv, seq_id); } // deprecated -void llama_kv_cache_seq_add( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); -} - void llama_kv_self_seq_add( llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - return kv->seq_add(seq_id, p0, p1, delta); + llama_memory_seq_add(kv, seq_id, p0, p1, delta); } // deprecated -void llama_kv_cache_seq_div( - llama_context * ctx, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); -} - void llama_kv_self_seq_div( llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return; } - return kv->seq_div(seq_id, p0, p1, d); + llama_memory_seq_div(kv, seq_id, p0, p1, d); } // deprecated -llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_self_seq_pos_max(ctx, seq_id); +llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) { + auto * kv = llama_get_memory(ctx); + if (!kv) { + return -1; + } + + return llama_memory_seq_pos_min(kv, seq_id); } +// deprecated llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { - return 0; + return -1; } - return kv->seq_pos_max(seq_id); + return llama_memory_seq_pos_max(kv, seq_id); } // deprecated -void llama_kv_cache_defrag(llama_context * ctx) { - return llama_kv_self_defrag(ctx); -} - void llama_kv_self_defrag(llama_context * ctx) { - auto * kv = ctx->get_kv_self(); - if (!kv) { - return; - } - - return kv->defrag(); + // force defrag + ctx->kv_self_defrag_sched(); } // deprecated -bool llama_kv_cache_can_shift(const llama_context * ctx) { - return llama_kv_self_can_shift(ctx); -} - bool llama_kv_self_can_shift(const llama_context * ctx) { - const auto * kv = ctx->get_kv_self(); + auto * kv = llama_get_memory(ctx); if (!kv) { return false; } - return kv->get_can_shift(); -} - -// deprecated -void llama_kv_cache_update(llama_context * ctx) { - llama_kv_self_update(ctx); + return llama_memory_can_shift(kv); } // llama state API @@ -2790,7 +2779,7 @@ int32_t llama_decode( llama_context * ctx, llama_batch batch) { const int ret = ctx->decode(batch); - if (ret != 0) { + if (ret != 0 && ret != 1) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -2829,3 +2818,34 @@ void llama_perf_context_print(const llama_context * ctx) { void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } + +// +// training +// + +bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) { + GGML_UNUSED(tensor); + GGML_UNUSED(userdata); + return true; +} + +void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) { + ctx->opt_init(model, lopt_params); +} + +void llama_opt_epoch( + struct llama_context * ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + ctx->opt_epoch( + dataset, + result_train, + result_eval, + idata_split, + callback_train, + callback_eval); +} diff --git a/src/llama-context.h b/src/llama-context.h index 04facb544..040f03ae4 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -1,22 +1,25 @@ #pragma once #include "llama.h" -#include "llama-batch.h" #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" #include "ggml-cpp.h" +#include "ggml-opt.h" #include #include struct llama_model; -struct llama_kv_cache; +class llama_batch_allocr; class llama_io_read_i; class llama_io_write_i; +struct llama_memory_i; +struct llama_memory_state_i; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -27,7 +30,12 @@ struct llama_context { void synchronize(); - const llama_model & get_model() const; + const llama_model & get_model() const; + const llama_cparams & get_cparams() const; + + ggml_backend_sched_t get_sched() const; + + ggml_context * get_ctx_compute() const; uint32_t n_ctx() const; uint32_t n_ctx_per_seq() const; @@ -38,10 +46,12 @@ struct llama_context { uint32_t n_threads() const; uint32_t n_threads_batch() const; - llama_kv_cache * get_kv_self(); - const llama_kv_cache * get_kv_self() const; + llama_memory_t get_memory() const; - void kv_self_update(); + // return true of the KV cache was updated + // TODO: remove + bool kv_self_update(bool optimize); + void kv_self_defrag_sched(); enum llama_pooling_type pooling_type() const; @@ -82,8 +92,18 @@ struct llama_context { int32_t il_start, int32_t il_end); - int encode(llama_batch & inp_batch); - int decode(llama_batch & inp_batch); + // process a single ubatch with a specific graph type + // if memory_state is provided, it will be applied first to the context's memory + // ret contains the status of the graph computation + // returns nullptr only if ret != GGML_STATUS_SUCCESS + llm_graph_result_ptr process_ubatch( + const llama_ubatch & ubatch, + llm_graph_type gtype, + llama_memory_state_i * mstate, + ggml_status & ret); + + int encode(const llama_batch & batch_inp); + int decode(const llama_batch & batch_inp); // // state save/load @@ -128,6 +148,32 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); + // + // training + // + + void opt_init(struct llama_model * model, struct llama_opt_params lopt_params); + + void opt_epoch( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + void opt_epoch_iter( + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + const std::vector & tokens, + const std::vector & labels_sparse, + llama_batch & batch, + ggml_opt_epoch_callback callback, + bool train, + int64_t idata_in_loop, + int64_t ndata_in_loop, + int64_t t_loop_start); + private: // // output @@ -135,52 +181,34 @@ private: // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - int32_t output_reserve(int32_t n_outputs); - - // make the outputs have the same order they had in the user-provided batch - // TODO: maybe remove this - void output_reorder(); + uint32_t output_reserve(int32_t n_outputs); // // graph // +public: int32_t graph_max_nodes() const; // zero-out inputs and create the ctx_compute for the compute graph ggml_cgraph * graph_init(); - llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype); - // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); + ggml_status graph_compute(ggml_cgraph * gf, bool batched); + + // reserve a graph with a dummy ubatch of the specified size + ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_state_i * mstate); + +private: + llm_graph_result_ptr graph_build( + ggml_context * ctx, + ggml_cgraph * gf, + const llama_ubatch & ubatch, + llm_graph_type gtype, + const llama_memory_state_i * mstate); llm_graph_cb graph_get_cb() const; - // used by kv_self_update() - ggml_tensor * build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale, - ggml_backend_buffer * bbuf) const; - - llm_graph_result_ptr build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const; - - llm_graph_result_ptr build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const; - // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -197,14 +225,13 @@ private: llama_cparams cparams; llama_adapter_cvec cvec; llama_adapter_loras loras; - llama_sbatch sbatch; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr kv_self; + std::unique_ptr memory; - // TODO: remove - bool logits_all = false; + // TODO: temporary, until the llama_kv_self_defrag() API is removed + bool memory_force_optimize = false; // decode output (2-dimensional array: [n_outputs][n_vocab]) size_t logits_size = 0; // capacity (of floats) for logits @@ -219,8 +246,10 @@ private: // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; - int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch - int32_t n_outputs_max = 0; // capacity (of tokens positions) for the output buffers + // reuse the batch_allocr to avoid unnecessary memory allocations + std::unique_ptr batch_allocr; + + uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch std::vector output_ids; // map batch token positions to ids of the logits and embd buffers @@ -231,6 +260,9 @@ private: ggml_context_ptr ctx_compute; + // training + ggml_opt_context_t opt_ctx = nullptr; + ggml_threadpool_t threadpool = nullptr; ggml_threadpool_t threadpool_batch = nullptr; diff --git a/src/llama-cparams.cpp b/src/llama-cparams.cpp index 28369be36..a3e7a37ee 100644 --- a/src/llama-cparams.cpp +++ b/src/llama-cparams.cpp @@ -1 +1,5 @@ #include "llama-cparams.h" + +size_t llama_max_parallel_sequences(void) { + return LLAMA_MAX_SEQ; +} diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 30e550f02..118615d5b 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -4,6 +4,8 @@ #include +#define LLAMA_MAX_SEQ 64 + struct llama_cparams { uint32_t n_ctx; // context size used during inference uint32_t n_batch; @@ -30,6 +32,7 @@ struct llama_cparams { bool flash_attn; bool no_perf; bool warmup; + bool op_offload; enum llama_pooling_type pooling_type; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 973b47ae0..bed706bb2 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1177,8 +1177,18 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token for (const auto & trigger_pattern : grammar.trigger_patterns) { if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { grammar.awaiting_trigger = false; - // get from the first match to the end of the string - auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + auto constrained_str = grammar.trigger_buffer.substr(start); // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cd955d63b..337fb5cb0 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -3,39 +3,15 @@ #include "llama-impl.h" #include "llama-batch.h" #include "llama-cparams.h" -#include "llama-kv-cache.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" #include #include #include -static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { - // TODO move to hparams if a T5 variant appears that uses a different value - const int64_t max_distance = 128; - - if (bidirectional) { - n_buckets >>= 1; - } - - const int64_t max_exact = n_buckets >> 1; - - int32_t relative_position = x - y; - int32_t relative_bucket = 0; - - if (bidirectional) { - relative_bucket += (relative_position > 0) * n_buckets; - relative_position = abs(relative_position); - } else { - relative_position = -std::min(relative_position, 0); - } - - int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); - relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); - relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); - - return relative_bucket; -} - void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -55,7 +31,21 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && pos) { const int64_t n_tokens = ubatch->n_tokens; - ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_token*ggml_element_size(pos)); + if (ubatch->token && n_pos_per_embd == 4) { + // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D + // the 3 first dims are the same, and 4th dim is all 0 + std::vector pos_data(n_tokens*n_pos_per_embd); + // copy the first dimension + for (int i = 0; i < n_tokens; ++i) { + pos_data[ i] = ubatch->pos[i]; + pos_data[ n_tokens + i] = ubatch->pos[i]; + pos_data[2 * n_tokens + i] = ubatch->pos[i]; + pos_data[3 * n_tokens + i] = 0; // 4th dim is 0 + } + ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos)); + } else { + ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos)); + } } } @@ -71,7 +61,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { ) * f_attn_temp_scale + 1.0; } - ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale)); + ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale)); } } @@ -96,22 +86,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) { if (pos_bucket) { - const int64_t n_tokens = ubatch->n_tokens; - - GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing - - int32_t * data = (int32_t *) pos_bucket->data; - - const int64_t n_kv = kv_self->n; - - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false); - } - } - } + kv_state->set_input_pos_bucket(pos_bucket, ubatch); } } @@ -164,6 +139,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { std::vector sum(n_tokens, 0); + // TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < n_seqs; ++s) { const llama_seq_id seq_id = ubatch->seq_id[s][0]; @@ -181,6 +157,7 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { } } + // TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < n_seqs; ++s) { const llama_seq_id seq_id = ubatch->seq_id[s][0]; @@ -205,6 +182,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { uint32_t * data = (uint32_t *) cls->data; memset(cls->data, 0, n_tokens * ggml_element_size(cls)); + // TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < n_seqs; ++s) { const llama_seq_id seq_id = ubatch->seq_id[s][0]; @@ -235,6 +213,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { std::vector last_pos(n_tokens, -1); std::vector last_row(n_tokens, -1); + // TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < n_seqs; ++s) { const llama_seq_id seq_id = ubatch->seq_id[s][0]; @@ -262,7 +241,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); - const int64_t n_kv = kv_self->n; + const int64_t n_kv = kv_state->get_n_kv(); if (s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer)); @@ -270,51 +249,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) { - kv_cell.src = cell_id; - } - - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } - } - } -} - -void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - const int64_t n_kv = kv_self->n; - - if (s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); - float * data = (float *) s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } + data[i] = kv_state->s_copy(i); } } } @@ -352,6 +287,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int32_t ti = s0*n_seq_tokens + i; float f = -INFINITY; + // TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) { if (hparams.use_alibi) { @@ -391,6 +327,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int32_t ti = s0*n_seq_tokens + i; float f = -INFINITY; + // TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { if (ubatch->seq_id[s0][s] == seq_id) { if (hparams.use_alibi) { @@ -417,99 +354,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { - if (self_kq_mask || self_kq_mask_swa) { - const int64_t n_kv = kv_self->n; - const int64_t n_tokens = ubatch->n_tokens; - const int64_t n_seq_tokens = ubatch->n_seq_tokens; - const int64_t n_seqs = ubatch->n_seqs; + if (self_kq_mask) { + kv_state->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } +} - float * data = nullptr; - float * data_swa = nullptr; +void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { + if (self_kq_mask) { + kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - if (self_kq_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - data = (float *) self_kq_mask->data; - } - - if (self_kq_mask_swa) { - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - data_swa = (float *) self_kq_mask_swa->data; - } - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - for (int h = 0; h < 1; ++h) { - for (int s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[s][0]; - - for (int j = 0; j < n_seq_tokens; ++j) { - const llama_pos pos = ubatch->pos[s*n_seq_tokens + j]; - for (int i = 0; i < n_kv; ++i) { - float f; - // mask the token if: - if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence - || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens - ) { - f = -INFINITY; - } else { - if (hparams.use_alibi) { - f = -std::abs(kv_self->cells[i].pos - pos); - } else { - f = 0.0f; - } - } - - if (data) { - data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - - // may need to cut off old tokens for sliding window - // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask" - if (data_swa) { - if (hparams.n_attn_chunk) { - llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk; - if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) { - f = -INFINITY; - } - } else { - if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) { - f = -INFINITY; - } - } - data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f; - } - } - } - } - - // mask padded tokens - if (data) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - - // mask padded tokens - if (data_swa) { - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; - } - } - } - } + if (self_kq_mask_swa) { + kv_state->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } } @@ -527,6 +383,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_enc; ++i) { float f = -INFINITY; + // TODO: fix indexing [UBATCH_IDX] for (int s = 0; s < ubatch->n_seq_id[j]; ++s) { const llama_seq_id seq_id = ubatch->seq_id[j][s]; if (cross->seq_ids_enc[i].find(seq_id) != cross->seq_ids_enc[i].end()) { @@ -559,7 +416,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : n_layer (hparams.n_layer), n_rot (hparams.n_rot), n_ctx (cparams.n_ctx), - n_ctx_per_seq (cparams.n_ctx / cparams.n_seq_max), n_head (hparams.n_head()), n_head_kv (hparams.n_head_kv()), n_embd_head_k (hparams.n_embd_head_k), @@ -586,14 +442,14 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : backend_cpu (params.backend_cpu), cvec (params.cvec), loras (params.loras), - memory (params.memory), + mstate (params.mstate), cross (params.cross), cb_func (params.cb), res (std::make_unique()) { } -int64_t llm_graph_context::n_pos_per_token() const { - return arch == LLM_ARCH_QWEN2VL ? 4 : 1; +int64_t llm_graph_context::n_pos_per_embd() const { + return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; } void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { @@ -785,6 +641,7 @@ ggml_tensor * llm_graph_context::build_ffn( { // Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf int64_t split_point = cur->ne[0] / 2; + // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); @@ -794,15 +651,33 @@ ggml_tensor * llm_graph_context::build_ffn( cur = ggml_mul(ctx0, x0, x1); cb(cur, "ffn_mul", il); } break; + case LLM_FFN_GEGLU: + { + // Split into two equal parts + int64_t split_point = cur->ne[0] / 2; + // TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217 + ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0)); + ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur))); + + x0 = ggml_gelu(ctx0, x0); + cb(x0, "ffn_gelu", il); + + cur = ggml_mul(ctx0, x0, x1); + cb(cur, "ffn_geglu", il); + } break; } - if (type_gate == LLM_FFN_PAR) { + if (gate && type_gate == LLM_FFN_PAR) { cur = ggml_mul(ctx0, cur, tmp); cb(cur, "ffn_gate_par", il); } if (down) { cur = build_lora_mm(down, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } } if (down_b) { @@ -900,9 +775,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); if (weight_before_ffn) { - // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d) - ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens); - repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens] + // repeat cur to [n_embd, n_expert_used, n_tokens] + ggml_tensor * repeated = ggml_repeat_4d(ctx0, cur, n_embd, n_expert_used, n_tokens, 1); cur = ggml_mul(ctx0, repeated, weights); cb(cur, "ffn_moe_weighted", il); } @@ -910,28 +784,35 @@ ggml_tensor * llm_graph_context::build_moe_ffn( ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); - ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(gate, "ffn_moe_gate", il); + ggml_tensor * experts = nullptr; + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } switch (type_op) { case LLM_FFN_SILU: { - gate = ggml_silu(ctx0, gate); - cb(gate, "ffn_moe_silu", il); + cur = ggml_silu(ctx0, cur); + cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: { - gate = ggml_gelu(ctx0, gate); - cb(gate, "ffn_moe_gelu", il); + cur = ggml_gelu(ctx0, cur); + cb(cur, "ffn_moe_gelu", il); } break; default: GGML_ABORT("fatal error"); } - ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens] - cb(par, "ffn_moe_gate_par", il); + if (gate_exps) { + cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate_par", il); + } - ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); if (!weight_before_ffn) { @@ -974,6 +855,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); //cb(inp->tokens, "inp_tokens", -1); ggml_set_input(inp->tokens); + res->t_tokens = inp->tokens; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1014,11 +896,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { } ggml_tensor * llm_graph_context::build_inp_pos() const { - auto inp = std::make_unique(n_pos_per_token()); + auto inp = std::make_unique(n_pos_per_embd()); auto & cur = inp->pos; - cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token()); + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd()); ggml_set_input(cur); res->add_input(std::move(inp)); @@ -1027,11 +909,12 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { } ggml_tensor * llm_graph_context::build_inp_attn_scale() const { - auto inp = std::make_unique(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); + auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); auto & cur = inp->attn_scale; - cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token()); + // this need to be 1x1xN for broadcasting + cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens); ggml_set_input(cur); res->add_input(std::move(inp)); @@ -1079,11 +962,11 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(kv_self); + auto inp = std::make_unique(kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_copy; @@ -1095,23 +978,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); - - auto inp = std::make_unique(kv_self); - - const auto n_kv = kv_self->n; - - auto & cur = inp->s_mask; - - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - ggml_set_input(cur); - - res->add_input(std::move(inp)); - - return cur; -} - ggml_tensor * llm_graph_context::build_inp_cross_embd() const { auto inp = std::make_unique(cross); @@ -1150,11 +1016,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const { } ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, kv_self); + auto inp = std::make_unique(hparams, kv_state); - const auto n_kv = kv_self->n; + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->pos_bucket; @@ -1188,18 +1054,13 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * v, ggml_tensor * kq_b, ggml_tensor * kq_mask, - bool v_trans, + ggml_tensor * v_mla, float kq_scale) const { - //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + const bool v_trans = v->nb[1] > v->nb[2]; - //const int64_t n_head = hparams.n_head(il); - //const int64_t n_head_kv = hparams.n_head_kv(il); - - //const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0]; + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 0, 2, 1, 3); const auto n_tokens = q->ne[1]; const auto n_head = q->ne[2]; @@ -1229,7 +1090,23 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); - cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens); + if (v_mla) { +#if 0 + // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. + // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. + cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); + cur = ggml_mul_mat(ctx0, v_mla, cur); +#else + // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1. + // The permutations are noops and only change how the tensor data is interpreted. + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_mul_mat(ctx0, v_mla, cur); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs. +#endif + } + + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); @@ -1267,9 +1144,14 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA + if (v_mla) { + kqv = ggml_mul_mat(ctx0, v_mla, kqv); + } - cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens); + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); if (!cparams.offload_kqv) { // all nodes between the KV store and the attention output are run on the CPU @@ -1304,6 +1186,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { GGML_UNUSED(n_tokens); @@ -1316,17 +1199,11 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask(); - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); - //cb(k, "k", il); - - ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); - //cb(k, "v", il); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * q = q_cur; + ggml_tensor * k = k_cur; + ggml_tensor * v = v_cur; + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1345,26 +1222,20 @@ ggml_tensor * llm_graph_context::build_attn( } llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_self); + auto inp = std::make_unique(hparams, cparams, kv_state); - const auto n_kv = kv_self->n; + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); - ggml_set_input(inp->self_kq_mask); + const auto n_kv = kv_state->get_n_kv(); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); - if (hparams.n_swa_pattern > 1) { - GGML_ASSERT(hparams.n_swa > 0); - - inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); - ggml_set_input(inp->self_kq_mask_swa); - - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); @@ -1379,6 +1250,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1387,84 +1259,105 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const llama_kv_cache_unified * kv_self = static_cast(memory); - const auto & n_ctx = cparams.n_ctx; - - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - const auto n_tokens = q_cur->ne[2]; - - const bool v_trans = !cparams.flash_attn; + const auto * kv_state = static_cast(mstate); // store to KV cache { - GGML_ASSERT(!kv_self->recurrent); - - const auto kv_head = kv_self->head; - - GGML_ASSERT(kv_self->size == n_ctx); - - ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head); - //cb(k_cache_view, "k_cache_view", il); - - // note: storing RoPE-ed version of K in the KV cache - ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view)); - - v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens); - - ggml_tensor * v_cache_view = nullptr; - - if (!v_trans) { - v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head); - } else { - // note: the V cache is transposed when not using flash attention - v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa, - ( n_ctx)*ggml_element_size(kv_self->v_l[il]), - (kv_head)*ggml_element_size(kv_self->v_l[il])); - - v_cur = ggml_transpose(ctx0, v_cur); - } - //cb(v_cache_view, "v_cache_view", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view)); + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); } + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(hparams, cparams, kv_state); + + { + const auto n_kv = kv_state->get_base()->get_n_kv(); + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + { + GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA"); + + const auto n_kv = kv_state->get_swa()->get_n_kv(); + + inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1); + ggml_set_input(inp->self_kq_mask_swa); + + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + } + + return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto * kv_state_iswa = static_cast(mstate); + const bool is_swa = hparams.is_swa(il); + const auto * kv_state = is_swa ? kv_state_iswa->get_swa() : kv_state_iswa->get_base(); + + // store to KV cache + { + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + } + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); - const auto n_kv = kv_self->n; + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); - const int64_t n_head_kv = hparams.n_head_kv(il); - - const auto & n_embd_head_k = hparams.n_embd_head_k; - const auto & n_embd_head_v = hparams.n_embd_head_v; - - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = - ggml_view_3d(ctx0, kv_self->k_l[il], - n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), - 0); - //cb(k, "k", il); - - ggml_tensor * v = !v_trans ? - ggml_view_3d(ctx0, kv_self->v_l[il], - n_embd_head_v, n_kv, n_head_kv, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v), - 0) : - ggml_view_3d(ctx0, kv_self->v_l[il], - n_kv, n_embd_head_v, n_head_kv, - ggml_element_size(kv_self->v_l[il])*n_ctx, - ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v, - 0); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale); + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1504,6 +1397,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, + ggml_tensor * v_mla, float kq_scale, int il) const { // these nodes are added to the graph together so that they are not reordered @@ -1514,17 +1408,11 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask_cross(); - ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3); - //cb(q, "q", il); - - ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3); - //cb(k, "k", il); - - ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3); - //cb(k, "v", il); - - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale); + ggml_tensor * q = q_cur; + ggml_tensor * k = k_cur; + ggml_tensor * v = v_cur; + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1542,55 +1430,65 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -ggml_tensor * llm_graph_context::build_copy_mask_state( +ggml_tensor * llm_graph_context::build_recurrent_state( ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, - ggml_tensor * state_mask, - int32_t n_state, - int32_t n_seqs) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { + const auto * kv_state = static_cast(mstate); - const auto n_kv = kv_self->n; - const auto kv_head = kv_self->head; + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); + const auto rs_zero = kv_state->get_rs_z(); - ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size()); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx0, states, state_copy); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? - states = ggml_mul(ctx0, states, state_mask); + ggml_tensor * output_states; - // copy states which won't be changed further (between n_seqs and n_kv) + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); + } else { + // FIXME: make the gathering operation happen before the copy below + // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) + output_states = states; + } + + // copy extra states which won't be changed further (between n_seqs and n_kv) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)), - ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); - // the part of the states that will be used and modified - return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0); + return output_states; } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_cgraph * gf, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * token_shift_all = kv_self->k_l[il]; + ggml_tensor * token_shift_all = kv_state->get_k_l(il); - ggml_tensor * token_shift = build_copy_mask_state( - gf, token_shift_all, state_copy, state_mask, + ggml_tensor * token_shift = build_recurrent_state( + gf, token_shift_all, state_copy, hparams.n_embd_k_s(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); @@ -1602,19 +1500,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; const int64_t n_seqs = ubatch.n_seqs; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); return ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0), - ggml_view_1d(ctx0, kv_self->k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self->k_l[il])) + ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il))) ); } @@ -1665,20 +1563,32 @@ void llm_graph_context::build_pooling( ggml_tensor * inp_cls = build_inp_cls(); inp = ggml_get_rows(ctx0, inp, inp_cls); - // classification head - // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 - GGML_ASSERT(cls != nullptr); - GGML_ASSERT(cls_b != nullptr); + if (cls) { + // classification head + // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 + cur = ggml_mul_mat(ctx0, cls, inp); + if (cls_b) { + cur = ggml_add(ctx0, cur, cls_b); + } + cur = ggml_tanh(ctx0, cur); - cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b); - cur = ggml_tanh(ctx0, cur); - - // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en - // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 - if (cls_out) { - GGML_ASSERT(cls_out_b != nullptr); - - cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b); + // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en + // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 + if (cls_out) { + cur = ggml_mul_mat(ctx0, cls_out, cur); + if (cls_out_b) { + cur = ggml_add(ctx0, cur, cls_out_b); + } + } + } else if (cls_out) { + // Single layer classification head (direct projection) + // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476 + cur = ggml_mul_mat(ctx0, cls_out, inp); + if (cls_out_b) { + cur = ggml_add(ctx0, cur, cls_out_b); + } + } else { + GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b"); } } break; default: @@ -1693,3 +1603,29 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } +int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { + // TODO move to hparams if a T5 variant appears that uses a different value + const int64_t max_distance = 128; + + if (bidirectional) { + n_buckets >>= 1; + } + + const int64_t max_exact = n_buckets >> 1; + + int32_t relative_position = x - y; + int32_t relative_bucket = 0; + + if (bidirectional) { + relative_bucket += (relative_position > 0) * n_buckets; + relative_position = abs(relative_position); + } else { + relative_position = -std::min(relative_position, 0); + } + + int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact)); + relative_position_if_large = std::min(relative_position_if_large, n_buckets - 1); + relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large); + + return relative_bucket; +} diff --git a/src/llama-graph.h b/src/llama-graph.h index 5b6618f9e..87813119b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -17,8 +17,11 @@ struct ggml_tensor; struct llama_ubatch; struct llama_cparams; -class llama_memory_i; -class llama_kv_cache_unified; +struct llama_memory_state_i; + +class llama_kv_cache_unified_state; +class llama_kv_cache_unified_iswa_state; +class llama_kv_cache_recurrent_state; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -33,6 +36,7 @@ enum llm_ffn_op_type { LLM_FFN_RELU, LLM_FFN_RELU_SQR, LLM_FFN_SWIGLU, + LLM_FFN_GEGLU, }; enum llm_ffn_gate_type { @@ -90,29 +94,27 @@ public: class llm_graph_input_pos : public llm_graph_input_i { public: - llm_graph_input_pos(int64_t n_pos_per_token) : n_pos_per_token(n_pos_per_token) {} + llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {} virtual ~llm_graph_input_pos() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * pos = nullptr; // I32 [n_batch] - const int64_t n_pos_per_token = 1; + const int64_t n_pos_per_embd = 1; }; // temperature tuning, used by llama4 class llm_graph_input_attn_temp : public llm_graph_input_i { public: - llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) - : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {} + llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) + : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {} virtual ~llm_graph_input_attn_temp() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * attn_scale = nullptr; // F32 [n_batch] - const int64_t n_pos_per_token = 1; - const uint32_t n_attn_temp_floor_scale; const float f_attn_temp_scale; }; @@ -133,7 +135,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i { public: llm_graph_input_pos_bucket_kv( const llama_hparams & hparams, - const llama_kv_cache_unified * kv_self) : hparams(hparams), kv_self(kv_self) {} + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), kv_state(kv_state) {} virtual ~llm_graph_input_pos_bucket_kv() = default; void set_input(const llama_ubatch * ubatch) override; @@ -141,7 +143,7 @@ public: ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch] const llama_hparams & hparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_state * kv_state; }; class llm_graph_input_out_ids : public llm_graph_input_i { @@ -188,26 +190,14 @@ public: class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_unified * kv_self; -}; - -class llm_graph_input_s_mask : public llm_graph_input_i { -public: - llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_s_mask() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * s_mask; // F32 [1, n_kv] - - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_recurrent_state * kv_state; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -247,15 +237,40 @@ public: llm_graph_input_attn_kv_unified( const llama_hparams & hparams, const llama_cparams & cparams, - const llama_kv_cache_unified * kv_self) : + const llama_kv_cache_unified_state * kv_state) : hparams(hparams), cparams(cparams), - kv_self(kv_self) { + kv_state(kv_state) { } ~llm_graph_input_attn_kv_unified() = default; void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + + const llama_hparams & hparams; + const llama_cparams & cparams; + + const llama_kv_cache_unified_state * kv_state; +}; + +class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { +public: + llm_graph_input_attn_kv_unified_iswa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_unified_iswa_state * kv_state) : + hparams(hparams), + cparams(cparams), + kv_state(kv_state) { + } + ~llm_graph_input_attn_kv_unified_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } @@ -267,7 +282,7 @@ public: const llama_hparams & hparams; const llama_cparams & cparams; - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_unified_iswa_state * kv_state; }; class llm_graph_input_attn_cross : public llm_graph_input_i { @@ -299,6 +314,7 @@ class llm_graph_result_i { public: virtual ~llm_graph_result_i() = default; + virtual ggml_tensor * get_tokens() = 0; virtual ggml_tensor * get_logits() = 0; virtual ggml_tensor * get_embd() = 0; virtual ggml_tensor * get_embd_pooled() = 0; @@ -313,6 +329,7 @@ class llm_graph_result : public llm_graph_result_i { public: virtual ~llm_graph_result() = default; + ggml_tensor * get_tokens() override { return t_tokens; } ggml_tensor * get_logits() override { return t_logits; } ggml_tensor * get_embd() override { return t_embd; } ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } @@ -329,6 +346,7 @@ public: } // important graph nodes + ggml_tensor * t_tokens = nullptr; ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; @@ -352,15 +370,15 @@ struct llm_graph_params { const llama_cparams & cparams; const llama_ubatch & ubatch; - ggml_backend_sched * sched; - ggml_backend * backend_cpu; + ggml_backend_sched_t sched; + ggml_backend_t backend_cpu; - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; - int32_t n_outputs; + uint32_t n_outputs; const llm_graph_cb & cb; }; @@ -376,7 +394,6 @@ struct llm_graph_context { const int64_t n_layer; const int64_t n_rot; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) - const int64_t n_ctx_per_seq; const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; @@ -395,8 +412,8 @@ struct llm_graph_context { const float norm_eps; const float norm_rms_eps; - const int32_t n_tokens; - const int32_t n_outputs; + const int64_t n_tokens; + const int64_t n_outputs; const int32_t n_ctx_orig; // yarn const enum llama_pooling_type pooling_type; @@ -404,14 +421,14 @@ struct llm_graph_context { ggml_context * ctx0 = nullptr; - ggml_backend_sched * sched; + ggml_backend_sched_t sched; - ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? + ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_i * memory; - const llama_cross * cross; + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_state_i * mstate; + const llama_cross * cross; const llm_graph_cb & cb_func; @@ -419,7 +436,7 @@ struct llm_graph_context { llm_graph_context(const llm_graph_params & params); - int64_t n_pos_per_token() const; + int64_t n_pos_per_embd() const; void cb(ggml_tensor * cur, const char * name, int il) const; @@ -492,7 +509,6 @@ struct llm_graph_context { ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; ggml_tensor * build_inp_s_copy() const; - ggml_tensor * build_inp_s_mask() const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -505,12 +521,12 @@ struct llm_graph_context { ggml_tensor * build_attn_mha( ggml_cgraph * gf, - ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q] - ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k] - ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false) + ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) ggml_tensor * kq_b, ggml_tensor * kq_mask, - bool v_trans, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale) const; llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const; @@ -524,6 +540,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -538,6 +555,22 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + + llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_unified_iswa * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -552,6 +585,7 @@ struct llm_graph_context { ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -559,18 +593,17 @@ struct llm_graph_context { // recurrent // - ggml_tensor * build_copy_mask_state( + ggml_tensor * build_recurrent_state( ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, - ggml_tensor * state_mask, - int32_t n_state, - int32_t n_seqs) const; + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const; @@ -590,3 +623,6 @@ struct llm_graph_context { ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; }; + +// TODO: better name +int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 90dfe7a7f..1499eb08a 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -2,6 +2,22 @@ #include "ggml.h" +void llama_hparams::set_swa_pattern(uint32_t n_pattern) { + for (uint32_t il = 0; il < n_layer; ++il) { + swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } +} + +bool llama_hparams::is_swa_any() const { + for (uint32_t il = 0; il < n_layer; ++il) { + if (swa_layers[il]) { + return true; + } + } + + return false; +} + uint32_t llama_hparams::n_head(uint32_t il) const { if (il < n_layer) { return n_head_arr[il]; @@ -72,7 +88,7 @@ uint32_t llama_hparams::n_embd_v_s() const { bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { - return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1); + return swa_layers[il]; } GGML_ABORT("fatal error"); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 4e0b57190..b2bcb8b01 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -14,6 +14,12 @@ enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2, }; +enum llama_swa_type { + LLAMA_SWA_TYPE_NONE = 0, + LLAMA_SWA_TYPE_STANDARD = 1, + LLAMA_SWA_TYPE_CHUNKED = 2, +}; + struct llama_hparams_posnet { uint32_t n_embd; uint32_t n_layer; @@ -35,14 +41,16 @@ struct llama_hparams { uint32_t n_embd_features = 0; uint32_t n_layer; uint32_t n_rot; - uint32_t n_swa = 0; // sliding window attention (SWA) - uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA + uint32_t n_embd_head_k_mla = 0; + uint32_t n_embd_head_v_mla = 0; + // for WavTokenizer struct llama_hparams_posnet posnet; struct llama_hparams_convnext convnext; @@ -62,6 +70,7 @@ struct llama_hparams { float expert_weights_scale = 0.0; bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; + uint32_t moe_every_n_layers = 0; float f_norm_eps; float f_norm_rms_eps; @@ -91,6 +100,15 @@ struct llama_hparams { std::array rope_sections; + // Sliding Window Attention (SWA) + llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + // the size of the sliding window (0 - no SWA) + uint32_t n_swa = 0; + // if swa_layers[il] == true, then layer il is SWA + // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA) + // by default, all layers are dense + std::array swa_layers; + // for State Space Models uint32_t ssm_d_conv = 0; uint32_t ssm_d_inner = 0; @@ -111,11 +129,13 @@ struct llama_hparams { bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; + bool use_kq_norm = true; + // for Classifiers + uint32_t n_cls_out = 1; + + // llama4 uint32_t n_moe_layer_step = 0; - bool use_kq_norm = true; - uint32_t n_attn_chunk = 0; - // values below seems to be fixed on llama4 uint32_t n_no_rope_layer_step = 4; uint32_t n_attn_temp_floor_scale = 8192; float f_attn_temp_scale = 0.1; @@ -128,6 +148,23 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) + // note that if n_pattern == 0, all layers are SWA + // if n_pattern == 1, all layers are dense + // example: n_pattern = 3 + // il == 0: swa + // il == 1: swa + // il == 2: dense + // il == 3: swa + // il == 4: swa + // il == 5: dense + // il == 6: swa + // etc ... + void set_swa_pattern(uint32_t n_pattern); + + // return true if one of the layers is SWA + bool is_swa_any() const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp new file mode 100644 index 000000000..8f6f120f6 --- /dev/null +++ b/src/llama-kv-cache-recurrent.cpp @@ -0,0 +1,1115 @@ +#include "llama-kv-cache-recurrent.h" + +#include "llama-impl.h" +#include "llama-io.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include +#include +#include +#include + +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_recurrent::clear(bool data) { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + + head = 0; + used = 0; + + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += shift; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const { + llama_pos result = std::numeric_limits::max(); + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::min(result, cells[i].pos); + } + } + + if (result == std::numeric_limits::max()) { + result = -1; + } + + return result; +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = -1; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) { + auto sbatch = llama_sbatch(batch, hparams.n_embd, false); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + if (!prepare(ubatches)) { + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_full() { + return std::make_unique(LLAMA_MEMORY_STATUS_SUCCESS, this); +} + +llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) { + GGML_UNUSED(lctx); + GGML_UNUSED(optimize); + + return std::make_unique(LLAMA_MEMORY_STATUS_NO_UPDATE); +} + +bool llama_kv_cache_recurrent::prepare(const std::vector & ubatches) { + // simply remember the full state because it is very small for this type of cache + // TODO: optimize + auto org_cells = cells; + auto org_used = used; + auto org_head = head; + + bool success = true; + + for (const auto & ubatch : ubatches) { + if (!find_slot(ubatch)) { + success = false; + break; + } + } + + // restore the original state + cells = std::move(org_cells); + used = org_used; + head = org_head; + + return success; +} + +bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) { + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_seqs) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max); + return false; + } + if (j > 0) { + kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + for (uint32_t i = 0; i < size; ++i) { + next_empty_cell += 1; + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + const int32_t dst_id = s + min; + const int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails + for (uint32_t i = 0; i < size; ++i) { + int32_t & tail = cells[i].tail; + if (tail == src_id) { + tail = dst_id; + } else if (tail == dst_id) { + tail = src_id; + } + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + const int32_t cell_id = s + min; + kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // Find first cell without src refs, to use as the zero-ed state + { + // TODO: bake-in src refcounts in the cell metadata + std::vector refcounts(size, 0); + for (size_t i = 0; i < size; ++i) { + const int32_t src = cells[i].src; + if (src >= 0) { + refcounts[src] += 1; + } + } + + rs_z = -1; + for (int i = min; i <= max; ++i) { + if (refcounts[i] == 0) { + rs_z = i; + break; + } + } + + for (int i = min; i <= max; ++i) { + if (cells[i].src < 0) { + GGML_ASSERT(rs_z >= 0); + cells[i].src0 = rs_z; + } else { + // Stage the source ids for all used cells to allow correct seq_* behavior + // and still make these values available when setting the inputs + cells[i].src0 = cells[i].src; + } + cells[i].src = i; // avoid moving or clearing twice + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const kv_cell & cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + // shifting the pos is trivial for recurrent models + return true; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(true); + + for (uint32_t i = 0; i < cell_count; ++i) { + kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (false != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_recurrent_state +// + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) { +} + +llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {} + +llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default; + +bool llama_kv_cache_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_recurrent_state::get_n_kv() const { + return is_full ? kv->size : kv->n; +} + +uint32_t llama_kv_cache_recurrent_state::get_head() const { + return is_full ? 0 : kv->head; +} + +int32_t llama_kv_cache_recurrent_state::get_rs_z() const { + return is_full ? 0 : kv->rs_z; +} + +uint32_t llama_kv_cache_recurrent_state::get_size() const { + return kv->size; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const { + return kv->k_l[il]; +} + +ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const { + return kv->v_l[il]; +} + +int32_t llama_kv_cache_recurrent_state::s_copy(int i) const { + return kv->cells[i + kv->head].src0; +} diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h new file mode 100644 index 000000000..f9b01a651 --- /dev/null +++ b/src/llama-kv-cache-recurrent.h @@ -0,0 +1,184 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-memory.h" + +#include +#include + +// +// llama_kv_cache_recurrent +// + +// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i +// see the implementation of llama_kv_cache_unified_state_i for an example how to do it +class llama_kv_cache_recurrent : public llama_memory_i { +public: + llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); + + ~llama_kv_cache_recurrent() = default; + + // + // llama_memory_i + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + bool prepare(const std::vector & ubatches); + + // find a contiguous slot of kv cells and emplace the ubatch there + bool find_slot(const llama_ubatch & ubatch); + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) + uint32_t size = 0; // total number of cells, shared across all sequences + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + // first zero-ed state + int32_t rs_z = -1; + + // TODO: optimize for recurrent state needs + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to know where states should be copied from + int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once) + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: + //const llama_model & model; + const llama_hparams & hparams; + + const uint32_t n_seq_max = 1; + + std::vector ctxs; + std::vector bufs; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_recurrent_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_recurrent_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv); + + // used to create a state from a batch + llama_kv_cache_recurrent_state( + llama_memory_status status, + llama_kv_cache_recurrent * kv, + llama_sbatch sbatch, + std::vector ubatches); + + virtual ~llama_kv_cache_recurrent_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_recurrent_state specific API + // + + uint32_t get_n_kv() const; + uint32_t get_head() const; + int32_t get_rs_z() const; + uint32_t get_size() const; + + ggml_tensor * get_k_l(int32_t il) const; + ggml_tensor * get_v_l(int32_t il) const; + + int32_t s_copy(int i) const; + +private: + const llama_memory_status status; + + llama_kv_cache_recurrent * kv; + + llama_sbatch sbatch; + + size_t i_next = 0; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // TODO: extract all the state like `head` and `n` here + // + + const bool is_full = false; +}; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp new file mode 100644 index 000000000..a4a4c2b1b --- /dev/null +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -0,0 +1,285 @@ +#include "llama-kv-cache-unified-iswa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include +#include + +// +// llama_kv_cache_unified_iswa +// + +llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad) : hparams(model.hparams) { + llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; + llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; + + const uint32_t size_base = kv_size; + + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + + // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size + if (swa_full) { + LLAMA_LOG_WARN("%s: using full-size SWA cache (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + + size_swa = size_base; + } + + LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + + kv_base = std::make_unique( + model, std::move(filter_base), type_k, type_v, + v_trans, offload, size_base, n_seq_max, n_pad, + 0, LLAMA_SWA_TYPE_NONE); + + LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); + + kv_swa = std::make_unique( + model, std::move(filter_swa), type_k, type_v, + v_trans, offload, size_swa, n_seq_max, n_pad, + hparams.n_swa, hparams.swa_type); +} + +void llama_kv_cache_unified_iswa::clear(bool data) { + kv_base->clear(data); + kv_swa ->clear(data); +} + +bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_base->seq_rm(seq_id, p0, p1); + res = res & kv_swa ->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) { + kv_base->seq_keep(seq_id); + kv_swa ->seq_keep(seq_id); +} + +void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_base->seq_add(seq_id, p0, p1, shift); + kv_swa ->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_base->seq_div(seq_id, p0, p1, d); + kv_swa ->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the base cache is a superset of the SWA cache, so we can just check the SWA cache + return kv_swa->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const { + return kv_swa->seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) { + GGML_UNUSED(embd_all); + + // first try simple split + do { + auto sbatch = llama_sbatch(batch, hparams.n_embd, true); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_simple(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + break; + } + + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + break; + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique( + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + } while (false); + + // if it fails, try equal split + do { + auto sbatch = llama_sbatch(batch, hparams.n_embd, false); + + std::vector ubatches; + + while (sbatch.n_tokens > 0) { + auto ubatch = sbatch.split_equal(n_ubatch); + + ubatches.push_back(ubatch); + } + + auto heads_base = kv_base->prepare(ubatches); + if (heads_base.empty()) { + break; + } + + auto heads_swa = kv_swa->prepare(ubatches); + if (heads_swa.empty()) { + break; + } + + assert(heads_base.size() == heads_swa.size()); + + return std::make_unique( + this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches)); + } while (false); + + // TODO: if we fail again, we should attempt different splitting strategies + // but to do that properly, we first have to refactor the batches to be more flexible + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_kv_cache_unified_iswa::get_can_shift() const { + return kv_base->get_size() == kv_swa->get_size(); +} + +void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_base->state_write(io, seq_id); + kv_swa ->state_write(io, seq_id); +} + +void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_base->state_read(io, seq_id); + kv_swa ->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const { + return kv_base.get(); +} + +llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const { + return kv_swa.get(); +} + +// +// llama_kv_cache_unified_iswa_state +// + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_full(); + state_swa = kv->get_swa ()->init_full(); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) { + state_base = kv->get_base()->init_update(lctx, optimize); + state_swa = kv->get_swa ()->init_update(lctx, optimize); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + sbatch(std::move(sbatch)), + ubatches(std::move(ubatches)) { + // note: here we copy the ubatches. not sure if this is ideal + state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)); + state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)); + + status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status()); +} + +llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default; + +bool llama_kv_cache_unified_iswa_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + state_base->next(); + state_swa ->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_iswa_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + bool res = true; + + res = res & state_base->apply(); + res = res & state_swa ->apply(); + + return res; +} + +std::vector & llama_kv_cache_unified_iswa_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_iswa_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_base() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast(state_base.get()); +} + +const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast(state_swa.get()); +} diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h new file mode 100644 index 000000000..6e941e1a4 --- /dev/null +++ b/src/llama-kv-cache-unified-iswa.h @@ -0,0 +1,133 @@ +#pragma once + +#include "llama-kv-cache-unified.h" + +#include + +// +// llama_kv_cache_unified_iswa +// + +// utilizes two instances of llama_kv_cache_unified +// the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers + +class llama_kv_cache_unified_iswa : public llama_memory_i { +public: + llama_kv_cache_unified_iswa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool swa_full, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_ubatch, + uint32_t n_pad); + + ~llama_kv_cache_unified_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified_iswa specific API + // + + llama_kv_cache_unified * get_base() const; + llama_kv_cache_unified * get_swa () const; + +private: + const llama_hparams & hparams; + + std::unique_ptr kv_base; + std::unique_ptr kv_swa; +}; + +class llama_kv_cache_unified_iswa_state : public llama_memory_state_i { +public: + // used for errors + llama_kv_cache_unified_iswa_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv); + + // used to create an update state + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_context * lctx, + bool optimize); + + // used to create a state from a batch + llama_kv_cache_unified_iswa_state( + llama_kv_cache_unified_iswa * kv, + llama_sbatch sbatch, + std::vector heads_base, + std::vector heads_swa, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_iswa_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_iswa_state specific API + // + + const llama_kv_cache_unified_state * get_base() const; + const llama_kv_cache_unified_state * get_swa() const; + +private: + llama_memory_status status; + + //llama_kv_cache_unified_iswa * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + llama_memory_state_ptr state_base; + llama_memory_state_ptr state_swa; +}; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp new file mode 100644 index 000000000..3b3767985 --- /dev/null +++ b/src/llama-kv-cache-unified.cpp @@ -0,0 +1,1835 @@ +#include "llama-kv-cache-unified.h" + +#include "llama-impl.h" +#include "llama-io.h" +#include "llama-model.h" +#include "llama-context.h" + +#include +#include +#include +#include +#include +#include + +// +// llama_kv_cache_unified +// + +llama_kv_cache_unified::llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type) : + model(model), hparams(model.hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + + GGML_ASSERT(kv_size % n_pad == 0); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*hparams.n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + head = 0; + + cells.resize(kv_size); + + for (uint32_t il = 0; il < hparams.n_layer; il++) { + if (filter && !filter(il)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il); + continue; + } + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, il, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k; + ggml_tensor * v; + + k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); + v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + + ggml_format_name(k, "cache_k_l%d", il); + ggml_format_name(v, "cache_v_l%d", il); + + map_layer_ids[il] = layers.size(); + layers.push_back({ il, k, v }); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + + ggml_backend_buffer_clear(buf, 0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } + + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); + debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; +} + +void llama_kv_cache_unified::clear(bool data) { + cells.reset(); + + head = 0; + + if (data) { + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } + } +} + +bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if (seq_id >= 0) { + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id) && cells.seq_rm(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } else { + // match any sequence + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + cells.rm(i); + + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } +} + +void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.seq_keep(i, seq_id)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cells.size() && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + if (shift == 0) { + return; + } + + uint32_t new_head = cells.size(); + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over all cells. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + if (cells.pos_add(i, shift)) { + if (new_head == cells.size()) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + head = new_head != cells.size() ? new_head : 0; +} + +void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. + if (p0 == p1) { + return; + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id)) { + cells.pos_div(i, d); + } + } +} + +llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + return cells.seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + return cells.seq_pos_max(seq_id); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + do { + auto sbatch = llama_sbatch(batch, hparams.n_embd, true); + + std::vector ubatches; + while (sbatch.n_tokens > 0) { + ubatches.push_back(sbatch.split_simple(n_ubatch)); + } + + auto heads = prepare(ubatches); + if (heads.empty()) { + break; + } + + return std::make_unique( + this, std::move(sbatch), std::move(heads), std::move(ubatches)); + } while (false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_full() { + return std::make_unique(this); +} + +llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) { + bool do_shift = get_has_shift(); + + defrag_info dinfo; + + // see if we need to defrag + { + bool do_defrag = optimize; + + const auto thold = lctx->get_cparams().defrag_thold; + + if (!do_defrag && thold > 0.0f) { + const auto n_kv = cells.used_max_p1(); + + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n_kv >= 2048 ? std::max(0.0f, 1.0f - (float(cells.get_used() + n_pad)/n_kv)) : 0.0f; + + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } + } + + if (do_defrag) { + dinfo = defrag_prepare(lctx->graph_max_nodes()); + } + } + + return std::make_unique(this, lctx, do_shift, std::move(dinfo)); +} + +llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector & ubatches) { + llama_kv_cache_unified::ubatch_heads res; + + struct state { + uint32_t head_old; // old position of the head, before placing the ubatch + uint32_t head_new; // new position of the head, after placing the ubatch + + llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + }; + + // remember the old state of the cells so we can restore it in the end + std::vector states; + + bool success = true; + + for (const auto & ubatch : ubatches) { + // only find a suitable slot for the ubatch. don't modify the cells yet + const int32_t head_new = find_slot(ubatch); + if (head_new < 0) { + success = false; + break; + } + + // remeber the position that we found + res.push_back(head_new); + + // store the old state of the cells in the recovery stack + states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)}); + + // now emplace the ubatch + apply_ubatch(head_new, ubatch); + } + + // iterate backwards and restore the cells to their original state + for (auto it = states.rbegin(); it != states.rend(); ++it) { + cells.set(it->head_new, it->cells); + head = it->head_old; + } + + if (!success) { + return {}; + } + + return res; +} + +bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) { + bool updated = false; + + auto * sched = lctx->get_sched(); + + if (do_shift) { + if (!get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); + } + + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); + + // apply K-shift if needed + if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched); + + auto * gf = lctx->graph_init(); + + auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute K-shift\n", __func__); + return updated; + } + + updated = true; + } + + cells.reset_shift(); + } + + if (!dinfo.empty()) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + + // apply moves: + { + const auto n_kv = dinfo.ids.size(); + + for (uint32_t i = 0; i < n_kv; ++i) { + assert(dinfo.ids[i] <= n_kv); + + if (dinfo.ids[i] == n_kv || dinfo.ids[i] == i) { + continue; + } + + cells.mv(i, dinfo.ids[i]); + } + + // reset the head so we can find the first free slot during the next ubatch + head = 0; + } + + ggml_backend_sched_reset(sched); + + auto * gf = lctx->graph_init(); + + auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo); + if (!res) { + LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); + return updated; + } + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); + return updated; + } + + res->set_inputs(nullptr); + + if (lctx->graph_compute(gf, false) != GGML_STATUS_SUCCESS) { + LLAMA_LOG_ERROR("%s: failed to compute defrag\n", __func__); + return updated; + } + + updated = true; + } + + return updated; +} + +int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const { + const uint32_t n_tokens = ubatch.n_tokens; + + uint32_t head_cur = this->head; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { + head_cur = 0; + } + + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return -1; + } + + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa); + + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + if (cells.is_empty(i)) { + ss += '.'; + } else { + assert(cells.seq_count(i) >= 1); + + if (cells.seq_count(i) == 1) { + ss += std::to_string(cells.seq_get(i)); + } else { + ss += 'M'; + } + } + if (i%256 == 255) { + ss += " *"; + ss += '\n'; + } + } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); + } + + if ((debug == 2 && n_swa > 0) || debug > 2) { + std::string ss; + for (uint32_t i = 0; i < cells.size(); ++i) { + std::string cur; + if (cells.is_empty(i)) { + cur = '.'; + } else { + cur = std::to_string(cells.pos_get(i)); + } + const int n = cur.size(); + for (int j = 0; j < 5 - n; ++j) { + cur += ' '; + } + ss += cur; + if (i%256 == 255) { + ss += " *"; + } + if (i%64 == 63) { + ss += '\n'; + } + } + LLAMA_LOG_DEBUG("\n%s\n", ss.c_str()); + } + + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (cells.seq_pos_min(s) < 0) { + continue; + } + + LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s)); + } + } + + uint32_t n_tested = 0; + + while (true) { + if (head_cur + n_tokens > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; + + // can we use this cell? either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - (disabled) mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(head_cur + i); + + if (!can_use && cells.seq_count(head_cur + i) == 1) { + const llama_pos pos_cell = cells.pos_get(head_cur + i); + + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(head_cur + i, seq_id)) { + // can_use = pos_cell >= pos; + //} + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i); + + // SWA mask + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + can_use = true; + } + } + } + + if (!can_use) { + found = false; + head_cur += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return -1; + } + } + + return head_cur; +} + +void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) { + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__); + LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs); + LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs); + } + + // keep track of the max sequence position that we would overwrite with this ubatch + // for non-SWA cache, this would be always empty + llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos_max_rm[s] = -1; + } + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) { + const uint32_t idx = s*ubatch.n_seq_tokens + j; + + if (!cells.is_empty(head_cur + idx)) { + assert(cells.seq_count(head_cur + idx) == 1); + + const llama_seq_id seq_id = cells.seq_get(head_cur + idx); + const llama_pos pos = cells.pos_get(head_cur + idx); + + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + + cells.rm(head_cur + idx); + } + + cells.pos_set(head_cur + idx, ubatch.pos[idx]); + + // TODO: fix indexing [UBATCH_IDX] + for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) { + cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]); + } + } + } + + // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence + // will be present in the cache. so we have to purge any position which is less than those we would overwrite + // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq_pos_max_rm[s] == -1) { + continue; + } + + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { + LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", + __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); + + seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1); + } + } + // move the head at the end of the slot + head = head_cur + ubatch.n_tokens; +} + +bool llama_kv_cache_unified::get_can_shift() const { + return true; +} + +uint32_t llama_kv_cache_unified::get_size() const { + return cells.size(); +} + +bool llama_kv_cache_unified::get_has_shift() const { + return cells.get_has_shift(); +} + +uint32_t llama_kv_cache_unified::get_n_kv() const { + return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + return ggml_view_3d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + ggml_row_size(k->type, hparams.n_embd_head_k), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), + 0); +} + +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + if (!v_trans) { + // note: v->nb[1] <= v->nb[2] + return ggml_view_3d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2] + 0); + } + + // note: v->nb[1] > v->nb[2] + return ggml_view_3d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, + ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, v->ne[1]), // v->nb[2] + 0); +} + +ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * k = layers[ikv].k; + + const int64_t n_tokens = k_cur->ne[2]; + + ggml_tensor * k_view = ggml_view_1d(ctx, k, + n_tokens*hparams.n_embd_k_gqa(il), + ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur); + + return ggml_cpy(ctx, k_cur, k_view); +} + +ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const { + const int32_t ikv = map_layer_ids.at(il); + + auto * v = layers[ikv].v; + + const int64_t n_tokens = v_cur->ne[2]; + + v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens); + + ggml_tensor * v_view = nullptr; + + if (!v_trans) { + v_view = ggml_view_1d(ctx, v, + n_tokens*hparams.n_embd_v_gqa(il), + ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur); + } else { + // note: the V cache is transposed when not using flash attention + v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il), + (v->ne[1])*ggml_element_size(v), + (head_cur)*ggml_element_size(v)); + + v_cur = ggml_transpose(ctx, v_cur); + } + + return ggml_cpy(ctx, v_cur, v_view); +} + +void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const uint32_t n_tokens = ubatch->n_tokens; + const uint32_t n_seq_tokens = ubatch->n_seq_tokens; + const uint32_t n_seqs = ubatch->n_seqs; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const int64_t n_kv = dst->ne[0]; + + // Use only the previous KV cells of the correct sequence for each token of the ubatch. + // It's assumed that if a token in the batch has multiple sequences, they are equivalent. + // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: + // Causal mask: + // xxx------- + // xxxx------ + // xxxxx----- + // Non-causal mask: + // xxxxx----- + // xxxxx----- + // xxxxx----- + // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + for (uint32_t h = 0; h < 1; ++h) { + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[s][0]; + + for (uint32_t j = 0; j < n_seq_tokens; ++j) { + const uint32_t idx = s*n_seq_tokens + j; + + const llama_pos p1 = ubatch->pos[idx]; + + for (uint32_t i = 0; i < n_kv; ++i) { + float f = 0.0f; + + bool masked = false; + + if (cells.is_empty(i)) { + masked = true; + } else { + const llama_pos p0 = cells.pos_get(i); + + // mask the token if not the same sequence + masked = masked || (!cells.seq_has(i, seq_id)); + + // mask future tokens + masked = masked || (causal_attn && p0 > p1); + + // apply SWA if any + masked = masked || (is_masked_swa(p0, p1)); + + if (!masked && hparams.use_alibi) { + f = -std::abs(p0 - p1); + } + } + + if (masked) { + f = -INFINITY; + } + + data[h*(n_kv*n_tokens) + idx*n_kv + i] = f; + } + } + } + + // mask padded tokens + if (data) { + for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) { + for (uint32_t i = 0; i < n_kv; ++i) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i); + } +} + +void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + const int64_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + + int32_t * data = (int32_t *) dst->data; + + const int32_t n_kv = dst->ne[0]; + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_kv; ++i) { + // the position when the cells is empty is irrelevant - it will be masked out later in the attention + const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i); + + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false); + } + } + } +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_unified::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & layer : layers) { + size_k_bytes += ggml_nbytes(layer.k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_unified::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & layer : layers) { + size_v_bytes += ggml_nbytes(layer.v); + } + + return size_v_bytes; +} + +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; + + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE + // @ngxson : this is a workaround + // for M-RoPE, we want to rotate the whole vector when doing KV shift + // a normal RoPE should work, we just need to use the correct ordering + // ref: https://github.com/ggml-org/llama.cpp/pull/13870 + ? LLAMA_ROPE_TYPE_NEOX + : hparams.rope_type; + + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 + ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) + : cparams.yarn_attn_factor; + + ggml_tensor * tmp; + + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + + tmp = ggml_rope_ext(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } + + return tmp; +} + +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * k_shift; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + kv_self->set_input_k_shift(k_shift); + } +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); + ggml_set_input(inp->k_shift); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + ggml_tensor * k = + ggml_view_3d(ctx, layer.k, + n_embd_head_k, n_head_kv, cells.size(), + ggml_row_size(layer.k->type, n_embd_head_k), + ggml_row_size(layer.k->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf, + const defrag_info & dinfo) const { + auto res = std::make_unique(); + + const auto & ids = dinfo.ids; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, layer.k, + n_embd_k_gqa, nm, + ggml_row_size(layer.k->type, n_embd_k_gqa), + ggml_row_size(layer.k->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + n_embd_v_gqa, nm, + ggml_row_size(layer.v->type, n_embd_v_gqa), + ggml_row_size(layer.v->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, i)); + + view_v_dst = ggml_view_2d(ctx, layer.v, + nm, n_embd_v_gqa, + ggml_row_size(layer.v->type, cells.size()), + ggml_row_size(layer.v->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { + const uint32_t n_layer = layers.size(); + + const uint32_t n_kv = cells.used_max_p1(); + const uint32_t n_used = cells.get_used(); + + assert(n_used <= n_kv); + + //const int64_t t_start = ggml_time_us(); + + // number of cells moved + uint32_t n_moves = 0; + + // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + //const uint32_t max_moves = max_nodes()/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + + // determine which KV cells to move where + defrag_info res; + auto & ids = res.ids; + + ids.resize(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + if (!cells.is_empty(i0)) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells.is_empty(i0 + nh)) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + if (cells.is_empty(is) || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; + + if (nf == nh) { + break; + } + } + + // this can only happen if `n_used` is not accurate, which would be a bug + GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); + + nf = 0; + + uint32_t i1 = is; + + // are we moving a continuous block of memory? + bool cont = false; + + // should we stop searching for the next move? + bool stop = false; + + // go back and move the nf cells to the hole + for (; i1 < n_kv; ++i1) { + if (cells.is_empty(i1) || ids[i1] != n_kv) { + if (n_moves == max_moves) { + stop = true; + break; + } + + cont = false; + continue; + } + + // this cell goes to (i0 + nf) + ids[i1] = i0 + nf; + + if (!cont) { + n_moves++; + cont = true; + } + + nf++; + + if (nf == nh) { + break; + } + } + + if (stop || n_moves == max_moves) { + break; + } + + //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); + + i0 += nh - 1; + } + + if (n_moves == 0) { + return {}; + } + + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); + + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); + + return res; +} + +bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + } + + return false; +} + +void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } + } + } + + if (cell_range_begin != cells.size()) { + cell_ranges.emplace_back(cell_range_begin, cells.size()); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + std::vector seq_ids; + + for (llama_seq_id cur = 0; cur < (int) n_seq_max; ++cur) { + if (cur == seq_id || seq_id == -1) { + if (cells.seq_has(i, cur)) { + seq_ids.push_back(cur); + } + } + } + + const llama_pos pos = cells.pos_get(i); + const uint32_t n_seq_id = seq_ids.size(); + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + for (const auto & seq_id : seq_ids) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } +} + +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = this->v_trans ? 1 : 0; + const uint32_t n_layer = layers.size(); + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)layer.k->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(layer.k, range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(layer.v, range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = cells.size(); + + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)layer.v->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(layer.v->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(layer.v, src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + ubatch.n_tokens = cell_count; + ubatch.n_seq_tokens = cell_count; + ubatch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 1) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + // read the sequence id, but directly discard it - we will use dest_seq_id instead + { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + } + + ubatch.pos[i] = pos; + ubatch.n_seq_id[i] = n_seq_id; + ubatch.seq_id[i] = &dest_seq_id; + } + + const auto head_cur = find_slot(ubatch); + if (head_cur < 0) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + apply_ubatch(head_cur, ubatch); + + // keep the head at the old position because we will read the KV data into it in state_read_data() + head = head_cur; + + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head_cur + cell_count <= cells.size()); + GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]); + GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]); + GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id)); + GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(true); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cells.pos_set(i, pos); + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); + return false; + } + + cells.seq_add(i, seq_id); + } + } + + head = 0; + } + + return true; +} + +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != layers.size()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); + return false; + } + + if (cell_count > cells.size()) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, cells.size()); + return false; + } + + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) layer.k->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!this->v_trans) { + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (const auto & layer : layers) { + const uint32_t il = layer.il; + + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)layer.v->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(layer.v->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * cells.size()) * v_size_el; + ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_unified_state +// + +llama_kv_cache_unified_state::llama_kv_cache_unified_state(llama_memory_status status) : status(status) {} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { + n_kv = kv->get_size(); + head = 0; +} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { + if (!do_shift && this->dinfo.empty()) { + status = LLAMA_MEMORY_STATUS_NO_UPDATE; + } +} + +llama_kv_cache_unified_state::llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + llama_kv_cache_unified::ubatch_heads heads, + std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) { +} + +llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default; + +bool llama_kv_cache_unified_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_unified_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + // no ubatches -> this is a KV cache update + if (ubatches.empty()) { + kv->update(lctx, do_shift, dinfo); + + return true; + } + + kv->apply_ubatch(heads[i_next], ubatches[i_next]); + + n_kv = kv->get_n_kv(); + head = heads[i_next]; + + return true; +} + +std::vector & llama_kv_cache_unified_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_unified_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_unified_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +uint32_t llama_kv_cache_unified_state::get_n_kv() const { + return n_kv; +} + +ggml_tensor * llama_kv_cache_unified_state::get_k(ggml_context * ctx, int32_t il) const { + return kv->get_k(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::get_v(ggml_context * ctx, int32_t il) const { + return kv->get_v(ctx, il, n_kv); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const { + return kv->cpy_k(ctx, k_cur, il, head); +} + +ggml_tensor * llama_kv_cache_unified_state::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const { + return kv->cpy_v(ctx, v_cur, il, head); +} + +void llama_kv_cache_unified_state::set_input_k_shift(ggml_tensor * dst) const { + kv->set_input_k_shift(dst); +} + +void llama_kv_cache_unified_state::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + kv->set_input_kq_mask(dst, ubatch, causal_attn); +} + +void llama_kv_cache_unified_state::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { + kv->set_input_pos_bucket(dst, ubatch); +} + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h new file mode 100644 index 000000000..d96571d95 --- /dev/null +++ b/src/llama-kv-cache-unified.h @@ -0,0 +1,308 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cells.h" +#include "llama-memory.h" + +#include +#include + +struct llama_cparams; +struct llama_hparams; +struct llama_model; +struct llama_context; + +// +// llama_kv_cache_unified +// + +class llama_kv_cache_unified : public llama_memory_i { +public: + static uint32_t get_padding(const llama_cparams & cparams); + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + + using ubatch_heads = std::vector; + + struct defrag_info { + bool empty() const { + return ids.empty(); + } + + // contains information about which cell moves where: + // - cell i moves to ids[i] + // - if ids[i] == i || ids[i] == ids.size(), then cell i is not moved + std::vector ids; + }; + + llama_kv_cache_unified( + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type); + + ~llama_kv_cache_unified() = default; + + // + // llama_memory_i + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_state_ptr init_full() override; + + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_unified specific API + // + + uint32_t get_size() const; + + bool get_has_shift() const; + + // + // graph_build API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const; + + // + // preparation API + // + + // find places for the provided ubatches in the cache, returns the head locations + // return empty vector on failure + ubatch_heads prepare(const std::vector & ubatches); + + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); + + // return the cell position where we can insert the ubatch + // return -1 on failure to find a contiguous slot of kv cells + int32_t find_slot(const llama_ubatch & ubatch) const; + + // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens) + void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch); + + // + // set_input API + // + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_k_shift (ggml_tensor * dst) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + const llama_model & model; + const llama_hparams & hparams; + + struct kv_layer { + // layer index in the model + // note: can be different from the layer index in the KV cache + uint32_t il; + + ggml_tensor * k; + ggml_tensor * v; + }; + + bool v_trans = true; // the value tensor is transposed + + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + uint32_t head = 0; + + const uint32_t n_seq_max = 1; + + // required padding + const uint32_t n_pad = 1; + + // SWA + const uint32_t n_swa = 0; + + int debug = 0; + + const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; + + std::vector ctxs; + std::vector bufs; + + llama_kv_cells_unified cells; + + std::vector layers; + + // model layer id -> KV cache layer id + std::unordered_map map_layer_ids; + + // return non-empty vector if cells have been moved + defrag_info defrag_prepare(int32_t n_max_nodes) const; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + + bool is_masked_swa(llama_pos p0, llama_pos p1) const; + + ggml_tensor * build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; + + llm_graph_result_ptr build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf, + const defrag_info & dinfo) const; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + +class llama_kv_cache_unified_state : public llama_memory_state_i { +public: + // some shorthands + using ubatch_heads = llama_kv_cache_unified::ubatch_heads; + using defrag_info = llama_kv_cache_unified::defrag_info; + + // used for errors + llama_kv_cache_unified_state(llama_memory_status status); + + // used to create a full-cache state + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv); + + // used to create an update state + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_context * lctx, + bool do_shift, + defrag_info dinfo); + + // used to create a decode state from a batch + llama_kv_cache_unified_state( + llama_kv_cache_unified * kv, + llama_sbatch sbatch, + ubatch_heads heads, + std::vector ubatches); + + virtual ~llama_kv_cache_unified_state(); + + // + // llama_memory_state_i + // + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_unified_state specific API + // + + uint32_t get_n_kv() const; + + // get views of the current state of the cache + ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; + + // store k_cur and v_cur in the cache based on the provided head location + ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const; + ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const; + + void set_input_k_shift(ggml_tensor * dst) const; + + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; + void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + +private: + llama_memory_status status; + + llama_kv_cache_unified * kv; + llama_context * lctx; + + // + // update state + // + + bool do_shift = false; + + defrag_info dinfo; + + // + // batch processing state + // + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + ubatch_heads heads; + + std::vector ubatches; + + // + // data needed for building the compute graph for the current ubatch: + // + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // as the cache gets filled, the benefit from this heuristic disappears + int32_t n_kv; + + // the beginning of the current slot in which the ubatch will be inserted + int32_t head; +}; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp deleted file mode 100644 index dbf5f1187..000000000 --- a/src/llama-kv-cache.cpp +++ /dev/null @@ -1,1380 +0,0 @@ -#include "llama-kv-cache.h" - -#include "llama-impl.h" -#include "llama-batch.h" -#include "llama-cparams.h" -#include "llama-model.h" - -#include -#include -#include -#include -#include - -llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { -} - -bool llama_kv_cache_unified::init( - const llama_model & model, - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload) { - const int32_t n_layer = hparams.n_layer; - - has_shift = false; - - recurrent = llama_model_is_recurrent(&model); - v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent && model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA - - LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", - __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); - - head = 0; - size = kv_size; - used = 0; - - this->type_k = type_k; - this->type_v = type_v; - - cells.clear(); - cells.resize(kv_size); - - // create a context for each buffer type - std::map ctx_map; - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - return nullptr; - } - - ctx_map[buft] = ctx; - ctxs.emplace_back(ctx); - - return ctx; - } - - return it->second; - }; - - k_l.reserve(n_layer); - v_l.reserve(n_layer); - - for (int i = 0; i < n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft; - if (offload) { - auto * dev = model.dev_layer(i); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } else { - buft = ggml_backend_cpu_buffer_type(); - } - - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, - i, n_embd_k_gqa, n_embd_v_gqa, dev_name); - - ggml_context * ctx = ctx_for_buft(buft); - if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); - return false; - } - - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); - } - - // allocate tensors and initialize the buffers to avoid NaNs in the padding - for (auto it : ctx_map) { - auto * buft = it.first; - auto * ctx = it.second; - - ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (!buf) { - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); - return false; - } - ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); - bufs.emplace_back(buf); - } - - return true; -} - -int32_t llama_kv_cache_unified::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_unified::get_used_cells() const { - return used; -} - -size_t llama_kv_cache_unified::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -llama_pos llama_kv_cache_unified::pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); - } - - return pos_max; -} - -void llama_kv_cache_unified::clear() { - for (int32_t i = 0; i < (int32_t) size; ++i) { - cells[i].pos = -1; - cells[i].seq_id.clear(); - cells[i].src = -1; - cells[i].tail = -1; - } - head = 0; - used = 0; - - for (auto & buf : bufs) { - ggml_backend_buffer_clear(buf.get(), 0); - } -} - -bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // models like Mamba or RWKV can't have a state partially erased - if (recurrent) { - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - const llama_kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - // invalidate tails which will be cleared - if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; - } - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - - return true; - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].pos >= p0 && cells[i].pos < p1) { - if (seq_id < 0) { - cells[i].seq_id.clear(); - } else if (cells[i].has_seq_id(seq_id)) { - cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cells[i].is_empty()) { - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].src = -1; - - if (new_head == size) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { - head = new_head; - } - - return true; -} - -void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - if (recurrent) { - if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - llama_kv_cell & tail_src = cells[seq_id_src]; - llama_kv_cell & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - llama_kv_cell & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.delta = -1; - cell_dst.src = -1; - used -= 1; - } - } - if (tail_src.tail >= 0) { - llama_kv_cell & cell_src = cells[tail_src.tail]; - - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; - } - } - - return; - } - - // otherwise, this is the KV of a Transformer-like model - head = 0; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id_src) && cells[i].pos >= p0 && cells[i].pos < p1) { - cells[i].seq_id.insert(seq_id_dst); - } - } -} - -void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { - uint32_t new_head = size; - - for (uint32_t i = 0; i < size; ++i) { - if (recurrent && (llama_seq_id) i != seq_id) { - cells[i].tail = -1; - } - - if (!cells[i].has_seq_id(seq_id)) { - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].src = -1; - cells[i].seq_id.clear(); - - if (new_head == size){ - new_head = i; - } - } else { - cells[i].seq_id.clear(); - cells[i].seq_id.insert(seq_id); - } - } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != size && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { - return; - } - - uint32_t new_head = size; - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the - if (p0 == p1) { - return; - } - - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; - } - } - } - return; - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; - cells[i].pos += delta; - cells[i].delta += delta; - - if (cells[i].pos < 0) { - if (!cells[i].is_empty()) { - used--; - } - cells[i].pos = -1; - cells[i].seq_id.clear(); - if (new_head == size) { - new_head = i; - } - } - } - } - - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - head = new_head != size ? new_head : 0; -} - -void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - - if (p0 < 0) { - p0 = 0; - } - - if (p1 < 0) { - p1 = std::numeric_limits::max(); - } - - // If there is no range then return early to avoid looping over the cache. - if (p0 == p1) { - return; - } - - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } - } - } - - return; - } - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { - has_shift = true; - - { - llama_pos p_old = cells[i].pos; - cells[i].pos /= d; - cells[i].delta += cells[i].pos - p_old; - } - } - } -} - -llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { - llama_pos result = 0; - - for (uint32_t i = 0; i < size; ++i) { - if (cells[i].has_seq_id(seq_id)) { - result = std::max(result, cells[i].pos); - } - } - - return result; -} - -void llama_kv_cache_unified::defrag() { - if (!recurrent) { - do_defrag = true; - } -} - -void llama_kv_cache_unified::restore() { - if (pending.ranges.empty()) { - return; - } - - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - seq_rm(-1, -1, -1); - return; - } - - uint32_t new_head = size; - - for (auto & range : pending.ranges) { - for (uint32_t i = range.c0; i < range.c1; ++i) { - cells[i].seq_id.clear(); - - // keep count of the number of used cells - if (cells[i].pos >= 0) { - used--; - } - - cells[i].pos = -1; - cells[i].src = -1; - } - - new_head = std::min(new_head, range.c0); - } - - if (new_head != size && new_head < head) { - head = new_head; - } -} - -void llama_kv_cache_unified::commit() { - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - return; - } - - if (pending.ranges.empty()) { - LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); - return; - } - - pending.ranges.clear(); -} - -bool llama_kv_cache_unified::get_can_shift() const { - return can_shift; -} - -bool llama_kv_cache_unified::find_slot( - const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*ubatch.n_tokens) { - head = 0; - } - - if (recurrent) { - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. - - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); - - int32_t min = size - 1; - int32_t max = 0; - - // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = ubatch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - - if (seq_id < 0 || (uint32_t) seq_id >= size) { - // too big seq_id - // TODO: would it be possible to resize the cache instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); - return false; - } - if (j > 0) { - llama_kv_cell & seq = cells[seq_id]; - if (seq.tail >= 0) { - llama_kv_cell & cell = cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - used -= 1; - } - } - } - } - } - -#ifndef NDEBUG - { - std::vector tails_verif; - tails_verif.assign(size, -1); - for (uint32_t i = 0; i < size; ++i) { - llama_kv_cell & cell = cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } - for (uint32_t i = 0; i < size; ++i) { - if (tails_verif[i] != cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); - } - } - } -#endif - - // find next empty cell - uint32_t next_empty_cell = head; - - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - llama_kv_cell & seq_meta = cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - llama_kv_cell & cell = cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? - if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - llama_kv_cell & empty_cell = cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - llama_kv_cell & orig_cell = cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } - - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - llama_kv_cell & dst_cell = cells[dst_id]; - llama_kv_cell & src_cell = cells[src_id]; - - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); - - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; - } - } - } - - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - llama_kv_cell & cell = cells[cell_id]; - - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cells[seq_id].tail = cell_id; - } - } - - // allow getting the range of used cells, from head to head + n - head = min; - n = max - min + 1; - used = std::count_if(cells.begin(), cells.end(), - [](const llama_kv_cell& cell){ return !cell.is_empty(); }); - - // sanity check - return n >= n_seqs; - } - - // otherwise, one cell per token. - - if (n_tokens > size) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); - return false; - } - - uint32_t n_tested = 0; - - while (true) { - if (head + n_tokens > size) { - n_tested += size - head; - head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cells[head + i].pos >= 0) { - found = false; - head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } - } - - for (uint32_t s = 0; s < n_seqs; s++) { - for (uint32_t i = 0; i < n_seq_tokens; ++i) { - uint32_t k = s*n_seq_tokens + i; - cells[head + k].pos = ubatch.pos[k]; - - for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { - cells[head + k].seq_id.insert(ubatch.seq_id[s][j]); - } - } - } - - used += n_tokens; - - pending.ranges.push_back({head, head + n_tokens}); - - return true; -} - -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 256u : 32u; -} - -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const llama_kv_cell & cell = cells[i - 1]; - - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } - } - - return 0; -} - -size_t llama_kv_cache_unified::size_k_bytes() const { - size_t size_k_bytes = 0; - - for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); - } - - return size_k_bytes; -} - -size_t llama_kv_cache_unified::size_v_bytes() const { - size_t size_v_bytes = 0; - - for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); - } - - return size_v_bytes; -} - -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { - const uint32_t n_layer = hparams.n_layer; - - const uint32_t n_kv = cell_max(); - const uint32_t n_used = used; - - assert(n_used <= n_kv); - - //const int64_t t_start = ggml_time_us(); - - // number of cells moved - uint32_t n_moves = 0; - - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); - - // determine which KV cells to move where - // - // cell i moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; - - ids.clear(); - ids.resize(n_kv, n_kv); - - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - const auto & cell0 = cells[i0]; - - if (!cell0.is_empty()) { - ids[i0] = i0; - - continue; - } - - // found a hole - fill it with data from the end of the cache - - uint32_t nh = 1; - - // determine the size of the hole - while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { - nh++; - } - - uint32_t nf = 0; - uint32_t is = n_kv - 1; - - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - const auto & cell1 = cells[is]; - - if (cell1.is_empty() || ids[is] != n_kv) { - continue; - } - - // non-empty cell which is not yet moved - nf++; - - if (nf == nh) { - break; - } - } - - // this can only happen if `n_used` is not accurate, which would be a bug - GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh"); - - nf = 0; - - uint32_t i1 = is; - - // are we moving a continuous block of memory? - bool cont = false; - - // should we stop searching for the next move? - bool stop = false; - - // go back and move the nf cells to the hole - for (; i1 < n_kv; ++i1) { - auto & cell1 = cells[i1]; - - if (cell1.is_empty() || ids[i1] != n_kv) { - if (n_moves == max_moves) { - stop = true; - break; - } - - cont = false; - continue; - } - - // this cell goes to (i0 + nf) - ids[i1] = i0 + nf; - - // move the cell meta data - cells[i0 + nf] = cell1; - - // clear the old cell and move the head there - cell1 = llama_kv_cell(); - head = n_used; - - if (!cont) { - n_moves++; - cont = true; - } - - nf++; - - if (nf == nh) { - break; - } - } - - if (stop || n_moves == max_moves) { - break; - } - - //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); - - i0 += nh - 1; - } - - if (n_moves == 0) { - return false; - } - - LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves); - - LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer); - - return true; -} - -void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; - - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = size; - for (uint32_t i = 0; i < size; ++i) { - const auto & cell = cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = size; - } - } - } - if (cell_range_begin != size) { - cell_ranges.emplace_back(cell_range_begin, size); - } - - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); - - io.write(&cell_count, sizeof(cell_count)); - - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); -} - -void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); - - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); - - if (!res) { - if (seq_id == -1) { - clear(); - } else { - seq_rm(seq_id, -1, -1); - } - throw std::runtime_error("failed to restore kv cache"); - } -} - -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { - for (uint32_t i = range.first; i < range.second; ++i) { - const auto & cell = cells[i]; - const llama_pos pos = cell.pos; - const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; - - io.write(&pos, sizeof(pos)); - io.write(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id) { - for (auto seq_id : cell.seq_id) { - io.write(&seq_id, sizeof(seq_id)); - } - } - } - } -} - -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { - const uint32_t v_trans = this->v_trans ? 1 : 0; - const uint32_t n_layer = hparams.n_layer; - - io.write(&v_trans, sizeof(v_trans)); - io.write(&n_layer, sizeof(n_layer)); - - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell - // Get whole range at a time - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Write key type - const int32_t k_type_i = (int32_t)k_l[il]->type; - io.write(&k_type_i, sizeof(k_type_i)); - - // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - io.write(&k_size_row, sizeof(k_size_row)); - - // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * k_size_row; - io.write_tensor(k_l[il], range.first * k_size_row, buf_size); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - io.write(&v_size_row, sizeof(v_size_row)); - - // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t buf_size = range_size * v_size_row; - io.write_tensor(v_l[il], range.first * v_size_row, buf_size); - } - } - } else { - // When v is transposed, we also need the element size and get the element ranges from each row - const uint32_t kv_size = size; - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Write value type - const int32_t v_type_i = (int32_t)v_l[il]->type; - io.write(&v_type_i, sizeof(v_type_i)); - - // Write element size - const uint32_t v_size_el = ggml_type_size(v_l[il]->type); - io.write(&v_size_el, sizeof(v_size_el)); - - // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); - - // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { - const size_t range_size = range.second - range.first; - const size_t src_offset = (range.first + j * kv_size) * v_size_el; - const size_t buf_size = range_size * v_size_el; - io.write_tensor(v_l[il], src_offset, buf_size); - } - } - } - } -} - -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { - if (dest_seq_id != -1) { - // single sequence - - seq_rm(dest_seq_id, -1, -1); - - llama_sbatch sbatch; - llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); - return false; - } - - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - if (!find_slot(batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; - } - commit(); - - // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(head + cell_count <= size); - GGML_ASSERT(cells[head].pos == batch.pos[0]); - GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); - GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore - - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } - - clear(); - - for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = cells[i]; - - llama_pos pos; - uint32_t n_seq_id; - - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); - - cell.pos = pos; - - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); - - // TODO: llama_kv_cache_unified should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); - return false; - } - - cell.seq_id.insert(seq_id); - - if (recurrent) { - int32_t & tail = cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } - } - } - - head = 0; - used = cell_count; - } - - if (recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = head + i; - // make sure the recurrent states will keep their restored state - cells[cell_id].src = cell_id; - } - } - - return true; -} - -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { - uint32_t v_trans; - uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); - - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); - return false; - } - if (cell_count > size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); - return false; - } - if (v_trans != (bool) v_trans) { - LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); - return false; - } - - // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - - // Read type of key - int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) k_l[il]->type; - if (k_type_i != k_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); - return false; - } - - // Read row size of key - uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - if (k_size_row != k_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); - } - } - - if (!v_trans) { - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read row size of value - uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); - if (v_size_row != v_size_row_ref) { - LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); - return false; - } - - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); - } - } - } else { - // For each layer, read the values for each cell (transposed) - for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); - - // Read type of value - int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)v_l[il]->type; - if (v_type_i != v_type_i_ref) { - LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); - return false; - } - - // Read element size of value - uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(v_l[il]->type); - if (v_size_el != v_size_el_ref) { - LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); - return false; - } - - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); - return false; - } - - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (head + j * size) * v_size_el; - ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } - } - } - } - - return true; -} - -// -// kv cache view -// - -llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max) { - llama_kv_cache_view result = { - /*.n_cells = */ 0, - /*.n_seq_max = */ n_seq_max, - /*.token_count = */ 0, - /*.used_cells = */ kv.get_used_cells(), - /*.max_contiguous = */ 0, - /*.max_contiguous_idx = */ -1, - /*.cells = */ nullptr, - /*.cells_sequences = */ nullptr, - }; - - return result; -} - -void llama_kv_cache_view_free(llama_kv_cache_view * view) { - if (view->cells != nullptr) { - free(view->cells); - view->cells = nullptr; - } - if (view->cells_sequences != nullptr) { - free(view->cells_sequences); - view->cells_sequences = nullptr; - } -} - -void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv) { - // TODO: rework this in the future, for now quick hack - const llama_kv_cache_unified * kvu = dynamic_cast(kv); - if (kvu == nullptr) { - LLAMA_LOG_ERROR("%s: the kv_cache_view currently works only with llama_kv_cache_unified\n", __func__); - return; - } - - if (uint32_t(view->n_cells) < kvu->size || view->cells == nullptr) { - view->n_cells = int32_t(kvu->size); - void * p = realloc(view->cells, sizeof(llama_kv_cache_view_cell) * view->n_cells); - GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); - view->cells = (llama_kv_cache_view_cell *)p; - p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_seq_max * view->n_cells); - GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences"); - view->cells_sequences = (llama_seq_id *)p; - } - - const std::vector & kv_cells = kvu->cells; - llama_kv_cache_view_cell * c_curr = view->cells; - llama_seq_id * cs_curr = view->cells_sequences; - int32_t used_cells = 0; - int32_t token_count = 0; - int32_t curr_contig_idx = -1; - uint32_t max_contig = 0; - int32_t max_contig_idx = -1; - - for (int32_t i = 0; i < int32_t(kvu->size); i++, c_curr++, cs_curr += view->n_seq_max) { - const size_t curr_size = kv_cells[i].seq_id.size(); - token_count += curr_size; - c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; - - if (curr_size > 0) { - if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) { - max_contig = i - curr_contig_idx; - max_contig_idx = curr_contig_idx; - } - curr_contig_idx = -1; - } else if (curr_contig_idx < 0) { - curr_contig_idx = i; - } - - int seq_idx = 0; - for (const llama_seq_id it : kv_cells[i].seq_id) { - if (seq_idx >= view->n_seq_max) { - break; - } - cs_curr[seq_idx] = it; - seq_idx++; - } - if (seq_idx != 0) { - used_cells++; - } - for (; seq_idx < view->n_seq_max; seq_idx++) { - cs_curr[seq_idx] = -1; - } - } - if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) { - max_contig_idx = curr_contig_idx; - max_contig = kv_cells.size() - curr_contig_idx; - } - view->max_contiguous = max_contig; - view->max_contiguous_idx = max_contig_idx; - view->token_count = token_count; - view->used_cells = used_cells; - if (uint32_t(used_cells) != kvu->used) { - LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", - __func__, kvu->used, used_cells); - } -} diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h deleted file mode 100644 index 56c74035a..000000000 --- a/src/llama-kv-cache.h +++ /dev/null @@ -1,213 +0,0 @@ -#pragma once - -#include "llama.h" -#include "llama-io.h" -#include "llama-memory.h" - -#include "ggml-cpp.h" - -#include -#include -#include - -struct llama_cparams; -struct llama_hparams; -struct llama_ubatch; - -struct llama_kv_cache : public llama_memory_i { - using llama_memory_i::llama_memory_i; - - virtual void restore() = 0; // call if batch processing fails - restores the cache state - virtual void commit() = 0; // call after successful batch processing - clears any pending state - - virtual int32_t get_n_tokens() const = 0; - virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache - - virtual bool get_can_shift() const = 0; - - bool get_can_edit() const override { return get_can_shift(); } -}; - -struct llama_kv_cache_guard { - llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} - - ~llama_kv_cache_guard() { - kv->restore(); - } - - void commit() { - kv->commit(); - } - -private: - llama_kv_cache * kv; -}; - -struct llama_kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - int32_t src = -1; // used by recurrent state models to copy states - int32_t tail = -1; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const llama_kv_cell & other) const { - return seq_id == other.seq_id; - } -}; - -// ring-buffer of cached KV data -// TODO: pimpl -// TODO: add notion of max sequences -class llama_kv_cache_unified : public llama_kv_cache { -public: - // can be used to query data from the model if needed - struct callbacks { - std::function get_rope_factors; - }; - - llama_kv_cache_unified( - const llama_hparams & hparams, - callbacks cbs); - - virtual ~llama_kv_cache_unified() = default; - - // TODO: become constructor - bool init( - const llama_model & model, // TODO: do not reference the model - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload); - - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - size_t total_size() const; - - // TODO: better data structures to reduce the cost of this operation - llama_pos pos_max() const; - - void clear() override; - void defrag() override; - - virtual void restore() override; - virtual void commit() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - bool get_can_shift() const override; - - // find an empty slot of size "n_tokens" in the cache - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. - bool find_slot(const llama_ubatch & batch); - - // TODO: maybe not needed - uint32_t get_padding(const llama_cparams & cparams) const; - - // find how many cells are currently in use - uint32_t cell_max() const; - - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - // defrag - - struct { - std::vector ids; - } defrag_info; - - // return true if cells have been moved - bool defrag_prepare(int32_t n_max_nodes); - - // commit/restore cache - - struct slot_range { - uint32_t c0 = 0; // note: these are cell indices, not sequence positions - uint32_t c1 = 0; - }; - - // pending cell updates that are not yet committed - struct { - std::vector ranges; - } pending; - - // state write/load - - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); - - // members - - const llama_hparams & hparams; - - callbacks cbs; - - bool has_shift = false; - bool do_defrag = false; - - // TODO: remove this and implement llama_kv_cache_recurrent instead - bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token - - bool v_trans = true; // the value tensor is transposed - bool can_shift = false; - - // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_impl also uses it, so it - // cannot be freely changed after a slot has been allocated. - uint32_t head = 0; - uint32_t size = 0; - uint32_t used = 0; // used cells (i.e. at least one seq_id) - - // computed before each graph build - uint32_t n = 0; - - std::vector cells; - - std::vector k_l; // per layer - std::vector v_l; - -private: - ggml_type type_k = GGML_TYPE_F16; - ggml_type type_v = GGML_TYPE_F16; - - std::vector ctxs; - std::vector bufs; - - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); -}; - -// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified -//class llama_kv_cache_recurrent : public llama_kv_cache_unified { -//public: -// using llama_kv_cache_unified::llama_kv_cache_unified; -//}; - -// -// kv cache view -// - -llama_kv_cache_view llama_kv_cache_view_init(const llama_kv_cache & kv, int32_t n_seq_max); - -void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache * kv); diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h new file mode 100644 index 000000000..1d4e70f4d --- /dev/null +++ b/src/llama-kv-cells.h @@ -0,0 +1,415 @@ +#pragma once + +#include "llama.h" +#include "llama-cparams.h" + +#include +#include +#include +#include + +// meta information about KV cells that can be part of multiple sequences at the same time +// TODO: add unit tests +class llama_kv_cells_unified { +public: + void reset() { + for (uint32_t i = 0; i < pos.size(); ++i) { + pos[i] = -1; + shift[i] = 0; + seq[i].reset(); + } + + has_shift = false; + + used.clear(); + + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + seq_pos[s].clear(); + } + } + + void reset_shift() { + has_shift = false; + + for (uint32_t i = 0; i < shift.size(); ++i) { + shift[i] = 0; + } + } + + uint32_t size() const { + return pos.size(); + } + + void resize(uint32_t n) { + pos.resize(n); + shift.resize(n); + seq.resize(n); + + reset(); + } + + bool is_empty(uint32_t i) const { + assert(i < pos.size()); + assert((pos[i] < 0 && pos[i] == -1) || pos[i] >= 0); + + return pos[i] == -1; + } + + uint32_t get_used() const { + return used.size(); + } + + // the index of the first cell that is used + // return 0 if no cells are used + uint32_t used_min() const { + return used.empty() ? 0 : *used.begin(); + } + + // the index of the last cell that is used + 1 + // return 0 if no cells are used + uint32_t used_max_p1() const { + return used.empty() ? 0 : *used.rbegin() + 1; + } + + bool get_has_shift() const { + return has_shift; + } + + // move cell isrc to idst (used during defrag) + void mv(uint32_t isrc, uint32_t idst) { + assert(isrc < pos.size()); + assert(idst < pos.size()); + + assert(pos[idst] == -1); + assert(pos[isrc] != -1); + + pos [idst] = pos [isrc]; + shift[idst] = shift[isrc]; + seq [idst] = seq [isrc]; + + pos [isrc] = -1; + shift[isrc] = 0; + seq [isrc].reset(); + + used.erase (isrc); + used.insert(idst); + } + + // copy the state of cells [i, i + n) (used for save/restore the state of the cells) + llama_kv_cells_unified cp(uint32_t i, uint32_t n) const { + assert(i + n <= pos.size()); + + llama_kv_cells_unified res; + + res.resize(n); + + for (uint32_t j = 0; j < n; ++j) { + res.pos[j] = pos[i + j]; + res.seq[j] = seq[i + j]; + + assert(shift[i + j] == 0); + } + + return res; + } + + // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells) + void set(uint32_t i, const llama_kv_cells_unified & other) { + assert(i + other.pos.size() <= pos.size()); + + for (uint32_t j = 0; j < other.pos.size(); ++j) { + if (pos[i + j] == -1 && other.pos[j] != -1) { + used.insert(i + j); + } + + if (pos[i + j] != -1 && other.pos[j] == -1) { + used.erase(i + j); + } + + if (pos[i + j] != -1) { + seq_pos_rm(i + j); + } + + pos[i + j] = other.pos[j]; + seq[i + j] = other.seq[j]; + + if (pos[i + j] != -1) { + seq_pos_add(i + j); + } + + assert(shift[i + j] == 0); + } + } + + // clear a non-empty cell + void rm(uint32_t i) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + seq[i].reset(); + + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + } + + // note: call only if the cell has seq_id + // return true if the cell becomes empty + bool seq_rm(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(seq[i].test(seq_id)); + assert(pos[i] != -1); + assert(seq_id >= 0); + + seq[i].reset(seq_id); + seq_pos[seq_id].erase(pos[i]); + + if (seq[i].none()) { + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + + return true; + } + + return false; + } + + // return true if the cell becomes empty (i.e. it did not contain seq_id before the call) + bool seq_keep(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + + if (seq[i].test(seq_id)) { + seq_pos_rm(i); + seq[i].reset(); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + + return false; + } + + if (seq[i].any()) { + seq_pos_rm(i); + seq[i].reset(); + + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + + return true; + } + + assert(pos[i] == -1); + + return false; + } + + // number of different sequences in the cell + int seq_count(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return seq[i].count(); + } + + // check if the cell contains seq_id + bool seq_has(uint32_t i, llama_seq_id seq_id) const { + assert(i < pos.size()); + assert(seq_id >= 0); + + return seq[i].test(seq_id); + } + + // note: call only if the cell is not empty and the seq_id is not in the cell + void seq_add(uint32_t i, llama_seq_id seq_id) { + assert(i < pos.size()); + assert(pos[i] != -1); + assert(!seq[i].test(seq_id)); + + seq[i].set(seq_id); + seq_pos[seq_id].insert(pos[i]); + } + + // return the sequence id of this cell + // note: call only for cells with exactly one sequence + llama_seq_id seq_get(uint32_t i) const { + assert(seq[i].count() == 1); + + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq[i].test(s)) { + return s; + } + } + + return -1; + } + + // the minimum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_min(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_SEQ); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].begin(); + } + + // the maximum position of sequence seq_id currently present in any of the cells + // return -1 if the sequence is not present + llama_pos seq_pos_max(llama_seq_id seq_id) const { + assert(seq_id >= 0); + assert(seq_id < LLAMA_MAX_SEQ); + + if (seq_pos[seq_id].empty()) { + return -1; + } + + return *seq_pos[seq_id].rbegin(); + } + + // note: call only if the cell is not empty + llama_pos pos_get(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return pos[i]; + } + + // note: call only if the cell is not empty + llama_pos get_shift(uint32_t i) const { + assert(i < pos.size()); + assert(pos[i] != -1); + + return shift[i]; + } + + // check if a cell is not empty and its position is within [p0, p1) + bool pos_in(uint32_t i, llama_pos p0, llama_pos p1) const { + assert(i < pos.size()); + + return pos[i] >= p0 && pos[i] < p1; + } + + // set the position of an empty cell + // does not modify "has_shift" + // note: call only if the cell is empty + void pos_set(uint32_t i, llama_pos p) { + assert(i < pos.size()); + assert(pos[i] == -1); + assert(seq[i].none()); + + pos[i] = p; + + used.insert(i); + } + + // pos[i] = pos[i] + d + // sets "has_shift" to true + // note: call only if the cell is not empty + bool pos_add(uint32_t i, llama_pos d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + seq_pos_rm(i); + + pos[i] += d; + shift[i] += d; + + has_shift = true; + + if (pos[i] < 0) { + seq[i].reset(); + pos[i] = -1; + shift[i] = 0; + + used.erase(i); + + return true; + } + + seq_pos_add(i); + + return false; + } + + // pos[i] = pos[i] / d + // sets "has_shift" to true + // note: call only if the cell is not empty + void pos_div(uint32_t i, int d) { + assert(i < pos.size()); + assert(pos[i] != -1); + + const llama_pos p_old = pos[i]; + + seq_pos_rm(i); + + pos[i] /= d; + shift[i] += p_old - pos[i]; + + seq_pos_add(i); + + has_shift = true; + } + +private: + bool has_shift = false; + + // set of indices of used cells (i.e. pos[i] != -1, allowed to not have any seq_id) + std::set used; + + std::vector pos; + + // this array accumulates any applied shifts to the pos array since the last reset_shift() call + // this is used to queue multiple updates to the pos array, which in the end can be applied in one go: + // + // cells.pos_add(x, shift_x); + // cells.pos_div(y, shift_y); + // ... + // + // if (cells.has_shift()) { + // for (int i = 0; i < n; ++i) { + // auto shift_i = cells.get_shift(i); + // ... + // } + // cells.reset_shift(); + // } + // + std::vector shift; + + using bits_t = std::bitset; + + // the bitset seq[i] tells us which sequences are currently occupying the i-th cell + std::vector seq; + + // the set seq_pos[s] tells us which positions are currently present for sequence s + // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache + std::set seq_pos[LLAMA_MAX_SEQ]; + + // helper functions for updating `seq_pos`, once cell at a time: + + // remove cell i + void seq_pos_rm(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq[i].test(s)) { + seq_pos[s].erase(pos[i]); + } + } + } + + // add cell i + void seq_pos_add(uint32_t i) { + for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + if (seq[i].test(s)) { + seq_pos[s].insert(pos[i]); + } + } + } +}; diff --git a/src/llama-memory.cpp b/src/llama-memory.cpp index 10173253e..f1107672c 100644 --- a/src/llama-memory.cpp +++ b/src/llama-memory.cpp @@ -1 +1,42 @@ #include "llama-memory.h" + +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) { + bool has_update = false; + + switch (s0) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s0; + } + } + + switch (s1) { + case LLAMA_MEMORY_STATUS_SUCCESS: + { + has_update = true; + break; + } + case LLAMA_MEMORY_STATUS_NO_UPDATE: + { + break; + } + case LLAMA_MEMORY_STATUS_FAILED_PREPARE: + case LLAMA_MEMORY_STATUS_FAILED_COMPUTE: + { + return s1; + } + } + + // if either status has an update, then the combined status has an update + return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE; +} diff --git a/src/llama-memory.h b/src/llama-memory.h index dfa8c4e90..24668f861 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -2,20 +2,116 @@ #include "llama.h" +#include +#include + +struct llama_ubatch; + +class llama_io_write_i; +class llama_io_read_i; + +struct llama_memory_params { + // kv cache + ggml_type type_k; + ggml_type type_v; + + // use full-size SWA cache + bool swa_full; +}; + +enum llama_memory_status { + LLAMA_MEMORY_STATUS_SUCCESS = 0, + LLAMA_MEMORY_STATUS_NO_UPDATE, + LLAMA_MEMORY_STATUS_FAILED_PREPARE, + LLAMA_MEMORY_STATUS_FAILED_COMPUTE, +}; + +// helper function for combining the status of two memory states +// useful for implementing hybrid memory types (e.g. iSWA) +llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1); + +// the interface for managing the memory state during batch processing +// this interface is implemented per memory type. see: +// - llama_kv_cache_unified_state +// - llama_kv_cache_unified_iswa_state +// ... +// +// the only method that can mutate the memory and the memory state is llama_memory_i::apply() +// +// TODO: rename to llama_memory_context_i ? +struct llama_memory_state_i { + virtual ~llama_memory_state_i() = default; + + // consume the current ubatch from the state and proceed to the next one + // return false if we are done + virtual bool next() = 0; + + // apply the memory state for the current ubatch to the memory object + // return false on failure + virtual bool apply() = 0; + + // TODO: this might get reworked in the future when refactoring llama_batch + virtual std::vector & out_ids() = 0; + + // get the current ubatch + virtual const llama_ubatch & get_ubatch() const = 0; + + // get the status of the memory state - used for error handling and checking if any updates would be applied + virtual llama_memory_status get_status() const = 0; +}; + +using llama_memory_state_ptr = std::unique_ptr; + // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types -class llama_memory_i { -public: - virtual void clear() = 0; - virtual void defrag() = 0; +struct llama_memory_i { + virtual ~llama_memory_i() = default; + + // split the input batch into a set of ubatches and verify that they can fit into the cache + // return a state object containing the ubatches and KV cache state required to process them + // check the llama_memory_state_i::get_status() for the result + virtual llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_all) = 0; + + // simulate full cache, used for allocating worst-case compute buffers + virtual llama_memory_state_ptr init_full() = 0; + + // prepare for any pending memory updates, such as shifts, defrags, etc. + // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update + virtual llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) = 0; + + // getters + virtual bool get_can_shift() const = 0; + + // + // ops + // + + // if data == true, the data buffers will also be cleared together with the metadata + virtual void clear(bool data) = 0; virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; virtual void seq_keep(llama_seq_id seq_id) = 0; - virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) = 0; + virtual void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) = 0; virtual void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) = 0; + virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0; virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0; - virtual bool get_can_edit() const = 0; + // + // state write/read + // + + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; +}; + +using llama_memory_ptr = std::unique_ptr; + +// TODO: temporary until the llama_kv_cache is removed from the public API +struct llama_kv_cache : public llama_memory_i { + virtual ~llama_kv_cache() = default; }; diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 9da97f1bc..47497cf95 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -401,7 +401,7 @@ struct llama_mmap::impl { } } #else - throw std::runtime_error("PrefetchVirtualMemory unavailable"); + LLAMA_LOG_DEBUG("skipping PrefetchVirtualMemory because _WIN32_WINNT < 0x602\n"); #endif } } diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index ea73a8a7b..bd9e6da88 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -288,9 +288,10 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const gguf_context * ctx = meta.get(); + const int kid = gguf_find_key(ctx, key.c_str()); - if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } @@ -298,28 +299,40 @@ namespace GGUFMeta { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; - case GGUF_TYPE_INT32: GGML_ASSERT( - (std::is_same::value) || - (std::is_same::value)); break; + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); } - result.resize(arr_info.length); - result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + if constexpr (std::is_same::value) { + const size_t n_items = gguf_get_arr_n(ctx, kid); + result.clear(); + + for (size_t i = 0; i < n_items; i++) { + const T value = gguf_get_arr_str(ctx, kid, i); + result.emplace_back(value); + } + } else { + result.resize(arr_info.length); + result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + } return true; } template bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const gguf_context * ctx = meta.get(); + const int kid = gguf_find_key(ctx, key.c_str()); - if (kid < 0 || gguf_get_kv_type(meta.get(), kid) != GGUF_TYPE_ARRAY) { + if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } @@ -327,22 +340,32 @@ namespace GGUFMeta { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { - case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; - case GGUF_TYPE_INT32: GGML_ASSERT( - (std::is_same::value) || - (std::is_same::value)); break; + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || + (std::is_same::value)); break; + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_STRING: GGML_ASSERT((std::is_same::value)); break; default: - throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); + throw std::runtime_error(format("%s is not a string/float32/uint32/int32 array", key.c_str())); } if (arr_info.length > N_MAX) { throw std::runtime_error(format("array length %u for key %s exceeds max %u", (uint32_t) arr_info.length, key.c_str(), (uint32_t) N_MAX)); } - std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + if constexpr (std::is_same::value) { + const size_t n_items = gguf_get_arr_n(ctx, kid); + + for (size_t i = 0; i < n_items; i++) { + const T value = gguf_get_arr_str(ctx, kid, i); + result[i] = value; + } + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } return true; } @@ -352,6 +375,8 @@ namespace GGUFMeta { return get_arr(llm_kv(kid), result, required); } + template bool llama_model_loader::get_arr>(enum llm_kv kid, std::vector & result, bool required); + template bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { auto it = kv_overrides.find(key); @@ -469,7 +494,7 @@ llama_model_loader::llama_model_loader( meta.reset(gguf_init_from_file(fname.c_str(), params)); if (!meta) { - throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); } get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); @@ -528,7 +553,7 @@ llama_model_loader::llama_model_loader( }; gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); + throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); } // check idx @@ -822,9 +847,18 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps mappings.reserve(files.size()); mmaps_used.reserve(files.size()); for (const auto & file : files) { - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); - auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); + bool is_numa = false; + + auto * dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (dev) { + auto * reg = ggml_backend_dev_backend_reg(dev); + auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); + if (is_numa_fn) { + is_numa = is_numa_fn(); + } + } + + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa); mmaps_used.emplace_back(mapping->size(), 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp new file mode 100644 index 000000000..a70b98923 --- /dev/null +++ b/src/llama-model-saver.cpp @@ -0,0 +1,281 @@ +#include "llama-model-saver.h" + +#include "gguf.h" + +#include "llama.h" +#include "llama-hparams.h" +#include "llama-model.h" +#include "llama-vocab.h" + +#include + +llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) { + gguf_ctx = gguf_init_empty(); +} + +llama_model_saver::~llama_model_saver() { + gguf_free(gguf_ctx); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) { + gguf_set_val_u32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const int32_t value) { + gguf_set_val_i32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const float value) { + gguf_set_val_f32(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const bool value) { + gguf_set_val_bool(gguf_ctx, llm_kv(key).c_str(), value); +} + +void llama_model_saver::add_kv(const enum llm_kv key, const char * value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), value); +} + +[[noreturn]] +void llama_model_saver::add_kv(const enum llm_kv key, const char value) { + GGML_UNUSED(key); + GGML_UNUSED(value); + GGML_ABORT("fatal error"); // this should never be called, only needed to make the template below compile +} + +template +void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { + const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size(); + GGML_ASSERT(n_values <= value.size()); + + if (n_values == 0) { + return; + } + + if (per_layer) { + bool all_values_the_same = true; + for (size_t i = 1; i < n_values; ++i) { + if (value[i] != value[0]) { + all_values_the_same = false; + break; + } + } + if (all_values_the_same) { + add_kv(key, value[0]); + return; + } + } + + if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_FLOAT32, value.data(), n_values); + } else if (std::is_same::value) { + gguf_set_val_str(gguf_ctx, llm_kv(key).c_str(), reinterpret_cast(value.data())); + } else { + GGML_ABORT("fatal error"); + } +} + +void llama_model_saver::add_kv(const enum llm_kv key, const std::vector & value) { + std::vector tmp(value.size()); + for (size_t i = 0; i < value.size(); ++i) { + tmp[i] = value[i].c_str(); + } + gguf_set_arr_str(gguf_ctx, llm_kv(key).c_str(), tmp.data(), tmp.size()); +} + +void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { + if (!tensor) { + return; + } + if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { + GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + return; + } + gguf_add_tensor(gguf_ctx, tensor); +} + +void llama_model_saver::add_kv_from_model() { + const llama_hparams & hparams = model.hparams; + const llama_vocab & vocab = model.vocab; + + const int32_t n_vocab = vocab.n_tokens(); + std::vector tokens(n_vocab); + std::vector scores(n_vocab); + std::vector token_types(n_vocab); + + for (int32_t id = 0; id < n_vocab; ++id) { + const llama_vocab::token_data & token_data = vocab.get_token_data(id); + + tokens[id] = token_data.text; + scores[id] = token_data.score; + + switch(token_data.attr) { + case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; + case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; + case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; + case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; + case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; + case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + case LLAMA_TOKEN_ATTR_UNDEFINED: + default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + } + } + + // add_kv(LLM_KV_GENERAL_TYPE, ???); + add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name()); + // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); + // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); + add_kv(LLM_KV_GENERAL_NAME, model.name); + // add_kv(LLM_KV_GENERAL_AUTHOR, ???); + // add_kv(LLM_KV_GENERAL_VERSION, ???); + // add_kv(LLM_KV_GENERAL_URL, ???); + // add_kv(LLM_KV_GENERAL_DESCRIPTION, ???); + // add_kv(LLM_KV_GENERAL_LICENSE, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_URL, ???); + // add_kv(LLM_KV_GENERAL_SOURCE_HF_REPO, ???); + + add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); + add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); + add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); + add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); + add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); + add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); + add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); + add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); + add_kv(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + + add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); + add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); + add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v); + add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + + const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; + + add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name + add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); + add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); + add_kv(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); + add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); + add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + + // TODO: implement split file support + // add_kv(LLM_KV_SPLIT_NO, ???); + // add_kv(LLM_KV_SPLIT_COUNT, ???); + // add_kv(LLM_KV_SPLIT_TENSORS_COUNT, ???); + + add_kv(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + + add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); + add_kv(LLM_KV_TOKENIZER_PRE, vocab.get_tokenizer_pre()); + add_kv(LLM_KV_TOKENIZER_LIST, tokens); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE, token_types); + add_kv(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, vocab.n_token_types()); + add_kv(LLM_KV_TOKENIZER_SCORES, scores); + add_kv(LLM_KV_TOKENIZER_MERGES, vocab.get_bpe_merges()); + // FIXME llama_token is type i32 but when reading in a GGUF file u32 is expected, not an issue for writing though + add_kv(LLM_KV_TOKENIZER_BOS_ID, uint32_t(vocab.token_bos())); + add_kv(LLM_KV_TOKENIZER_EOS_ID, uint32_t(vocab.token_eos())); + add_kv(LLM_KV_TOKENIZER_EOT_ID, uint32_t(vocab.token_eot())); + add_kv(LLM_KV_TOKENIZER_EOM_ID, uint32_t(vocab.token_eom())); + add_kv(LLM_KV_TOKENIZER_UNK_ID, uint32_t(vocab.token_unk())); + add_kv(LLM_KV_TOKENIZER_SEP_ID, uint32_t(vocab.token_sep())); + add_kv(LLM_KV_TOKENIZER_PAD_ID, uint32_t(vocab.token_pad())); + // add_kv(LLM_KV_TOKENIZER_CLS_ID, uint32_t(vocab.token_bos())); // deprecated + // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???); + add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos()); + add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos()); + add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix()); + add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces()); + add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap()); + // add_kv(LLM_KV_TOKENIZER_HF_JSON, ???); + // add_kv(LLM_KV_TOKENIZER_RWKV, ???); + add_kv(LLM_KV_TOKENIZER_FIM_PRE_ID, uint32_t(vocab.token_fim_pre())); + add_kv(LLM_KV_TOKENIZER_FIM_SUF_ID, uint32_t(vocab.token_fim_suf())); + add_kv(LLM_KV_TOKENIZER_FIM_MID_ID, uint32_t(vocab.token_fim_mid())); + add_kv(LLM_KV_TOKENIZER_FIM_PAD_ID, uint32_t(vocab.token_fim_pad())); + add_kv(LLM_KV_TOKENIZER_FIM_REP_ID, uint32_t(vocab.token_fim_rep())); + add_kv(LLM_KV_TOKENIZER_FIM_SEP_ID, uint32_t(vocab.token_fim_sep())); + + // TODO: implement LoRA support + // add_kv(LLM_KV_ADAPTER_TYPE, ???); + // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + + // deprecated + // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); + // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); +} + +void llama_model_saver::add_tensors_from_model() { + if (std::string(model.output->name) != std::string(model.tok_embd->name)) { + add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output + } + add_tensor(model.type_embd); + add_tensor(model.pos_embd); + add_tensor(model.tok_norm); + add_tensor(model.tok_norm_b); + add_tensor(model.output_norm); + add_tensor(model.output_norm_b); + add_tensor(model.output); + add_tensor(model.output_b); + add_tensor(model.output_norm_enc); + add_tensor(model.cls); + add_tensor(model.cls_b); + add_tensor(model.cls_out); + add_tensor(model.cls_out_b); + + for (const struct llama_layer & layer : model.layers) { + for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { + add_tensor(reinterpret_cast(&layer)[i]); + } + } +} + +void llama_model_saver::save(const std::string & path_model) { + gguf_write_to_file(gguf_ctx, path_model.c_str(), false); +} + diff --git a/src/llama-model-saver.h b/src/llama-model-saver.h new file mode 100644 index 000000000..a5a434c30 --- /dev/null +++ b/src/llama-model-saver.h @@ -0,0 +1,37 @@ +#pragma once + +#include "llama.h" +#include "llama-arch.h" + +#include + +struct llama_model_saver { + struct gguf_context * gguf_ctx = nullptr; + const struct llama_model & model; + const struct LLM_KV llm_kv; + + llama_model_saver(const struct llama_model & model); + ~llama_model_saver(); + + void add_kv(enum llm_kv key, uint32_t value); + void add_kv(enum llm_kv key, int32_t value); + void add_kv(enum llm_kv key, float value); + void add_kv(enum llm_kv key, bool value); + void add_kv(enum llm_kv key, const char * value); + + [[noreturn]] + void add_kv(enum llm_kv key, char value); // needed to make the template below compile + + template + void add_kv(enum llm_kv key, const Container & value, bool per_layer = false); + + void add_kv(enum llm_kv key, const std::vector & value); + + void add_tensor(const struct ggml_tensor * tensor); + + void add_kv_from_model(); + + void add_tensors_from_model(); + + void save(const std::string & path_model); +}; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b74dd72cf..a5eb122f9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5,7 +5,10 @@ #include "llama-batch.h" #include "llama-cparams.h" #include "llama-model-loader.h" -#include "llama-kv-cache.h" + +#include "llama-kv-cache-unified.h" +#include "llama-kv-cache-unified-iswa.h" +#include "llama-kv-cache-recurrent.h" #include "ggml-cpp.h" @@ -40,14 +43,17 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_335M: return "335M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; + case LLM_TYPE_475M: return "475M"; case LLM_TYPE_770M: return "770M"; case LLM_TYPE_780M: return "780M"; case LLM_TYPE_0_5B: return "0.5B"; + case LLM_TYPE_0_6B: return "0.6B"; case LLM_TYPE_1B: return "1B"; case LLM_TYPE_1_3B: return "1.3B"; case LLM_TYPE_1_4B: return "1.4B"; case LLM_TYPE_1_5B: return "1.5B"; case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_1_7B: return "1.7B"; case LLM_TYPE_1_8B: return "1.8B"; case LLM_TYPE_2B: return "2B"; case LLM_TYPE_2_8B: return "2.8B"; @@ -66,6 +72,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_15B: return "15B"; case LLM_TYPE_16B: return "16B"; case LLM_TYPE_20B: return "20B"; + case LLM_TYPE_27B: return "27B"; case LLM_TYPE_30B: return "30B"; case LLM_TYPE_32B: return "32B"; case LLM_TYPE_34B: return "34B"; @@ -73,8 +80,11 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_40B: return "40B"; case LLM_TYPE_65B: return "65B"; case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_142B: return "142B"; case LLM_TYPE_236B: return "236B"; + case LLM_TYPE_290B: return "290B"; case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_405B: return "405B"; case LLM_TYPE_671B: return "671B"; case LLM_TYPE_SMALL: return "0.1B"; case LLM_TYPE_MEDIUM: return "0.4B"; @@ -88,10 +98,10 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16x3_8B: return "16x3.8B"; case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; case LLM_TYPE_57B_A14B: return "57B.A14B"; - case LLM_TYPE_27B: return "27B"; - case LLM_TYPE_290B: return "290B"; case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; + case LLM_TYPE_30B_A3B: return "30B.A3B"; + case LLM_TYPE_235B_A22B: return "235B.A22B"; default: return "?B"; } } @@ -111,6 +121,10 @@ static const std::map LLAMA_ROPE_SCALING_ { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, }; +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type) { + return LLAMA_ROPE_SCALING_TYPES.at(rope_scaling_type); +} + static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { if (kv.second == name) { @@ -293,6 +307,10 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // add extra buffer types, only if no GPU device is present // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094 auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); @@ -449,11 +467,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { GGML_ASSERT(hparams.n_expert_used == 0); } - // zero-out the array hparams std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); + + std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -523,6 +544,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { uint32_t n_vocab = 0; ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + // for classifier models + ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false); + if (!classifier_labels.empty()) { + hparams.n_cls_out = classifier_labels.size(); + } + // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: @@ -557,9 +584,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - hparams.n_swa_pattern = 4; // pattern: 3 chunked - 1 full - hparams.n_attn_chunk = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick - hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later + + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick + hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full switch (hparams.n_expert) { case 16: type = LLM_TYPE_17B_16E; break; @@ -571,12 +599,23 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.use_kq_norm = false; } } break; + case LLM_ARCH_ARCEE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Arcee uses the same structure as Llama + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_DECI: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; + case 162: type = LLM_TYPE_405B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -695,13 +734,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); if (hparams.n_layer == 12 && hparams.n_embd == 768) { - type = LLM_TYPE_137M; + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } + } break; + case LLM_ARCH_NEO_BERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + if (hparams.n_layer == 28) { + type = LLM_TYPE_250M; } } break; case LLM_ARCH_BLOOM: @@ -762,6 +817,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // fall through case LLM_ARCH_QWEN2: { + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; @@ -791,6 +847,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { + case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -800,6 +860,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -824,22 +886,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } - // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 - if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) { - // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct - hparams.n_swa = 2047; - } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) { - // default value for Phi-3-mini-128k-instruct - // note: this seems incorrect because the window is bigger than the train context? - hparams.n_swa = 262144; - } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) { - // default value for Phi-3-medium-128k-instruct - // note: this seems incorrect because the window is equal to the train context? - hparams.n_swa = 131072; - } - bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (!found_swa && hparams.n_swa == 0) { - throw std::runtime_error("invalid value for sliding_window"); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + + // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + + hparams.n_swa = 0; + hparams.set_swa_pattern(1); } } break; case LLM_ARCH_PHIMOE: @@ -909,8 +966,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_GEMMA2: { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; // default value of gemma 2 - hparams.n_swa_pattern = 2; + hparams.set_swa_pattern(2); hparams.attn_soft_cap = true; ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -924,10 +982,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 46: type = LLM_TYPE_27B; break; default: type = LLM_TYPE_UNKNOWN; } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); } break; case LLM_ARCH_GEMMA3: { - hparams.n_swa_pattern = 6; + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(6); hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; @@ -943,6 +1007,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); @@ -1011,7 +1076,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_COHERE2: { - hparams.n_swa_pattern = 4; + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(4); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -1156,6 +1222,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); @@ -1359,6 +1427,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Add additional layer/vocab/etc checks here for other model sizes default: type = LLM_TYPE_UNKNOWN; } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); } break; case LLM_ARCH_CHAMELEON: { @@ -1394,6 +1465,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DOTS1: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_142B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -1462,6 +1547,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { @@ -1629,8 +1717,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { std::regex pattern(overrides->pattern); if (std::regex_search(tensor_name, pattern)) { - LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft)); buft = overrides->buft; + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", + tensor_name.c_str(), + ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), + ggml_backend_buft_name(buft)); break; } } @@ -1647,6 +1738,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto * buft_dev = ggml_backend_buft_get_device(buft); if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error("no CPU backend found"); + } buft = ggml_backend_dev_buffer_type(cpu_dev); } @@ -1733,6 +1827,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } } } } break; @@ -1828,7 +1929,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (n_ff > 0) { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); @@ -1838,9 +1941,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (n_ff > 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // optional MLP bias layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); @@ -2055,9 +2160,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); if (arch == LLM_ARCH_BERT) { pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); @@ -2065,8 +2171,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -2075,7 +2181,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - if (arch == LLM_ARCH_BERT) { + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (!layer.wqkv) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); @@ -2084,8 +2193,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } else { - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); } layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -2093,21 +2200,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - - if (arch == LLM_ARCH_BERT) { + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); } else { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) { + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } } layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); } } break; + case LLM_ARCH_NEO_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; case LLM_ARCH_JINA_BERT_V2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings @@ -2145,8 +2285,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -2422,7 +2562,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -3205,8 +3349,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { const bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -3232,14 +3382,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!is_lite) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); } - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3458,7 +3616,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4022,6 +4184,89 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); } } break; + case LLM_ARCH_DOTS1: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_ARCEE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -4063,6 +4308,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!dev) { // FIXME: workaround for CPU backend buft having a NULL device dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } } ggml_backend_dev_props props; ggml_backend_dev_get_props(dev, &props); @@ -4192,7 +4440,7 @@ uint64_t llama_model::n_elements() const { } void llama_model::print_info() const { - const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train); + const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); auto print_f = [](const std::function & f, uint32_t n) { bool is_var = false; @@ -4235,7 +4483,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: n_swa_pattern = %u\n", __func__, hparams.n_swa_pattern); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); @@ -4253,7 +4501,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); @@ -4263,6 +4511,15 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + + if (!classifier_labels.empty()) { + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); + + size_t i = 0; + for (auto label : classifier_labels) { + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + } + } } LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); @@ -4290,6 +4547,8 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); @@ -4307,10 +4566,13 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } - if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) { + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { @@ -4398,6 +4660,29 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const { return it->second; } +float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; +} + +float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; +} + +ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -4413,22 +4698,13 @@ struct llm_build_llama : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - // temperature tuning - ggml_tensor * inp_attn_scale = nullptr; - if (arch == LLM_ARCH_LLAMA4) { - inp_attn_scale = build_inp_attn_scale(); - } - auto * inp_attn = build_attn_inp_kv_unified(); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - bool use_rope = arch == LLM_ARCH_LLAMA4 - ? (il + 1) % hparams.n_no_rope_layer_step != 0 - : true; - // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -4438,7 +4714,169 @@ struct llm_build_llama : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_llama_iswa : public llm_graph_context { + llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + inp_attn_scale = build_inp_attn_scale(); + + auto * inp_attn = build_attn_inp_kv_unified_iswa(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -4486,7 +4924,7 @@ struct llm_build_llama : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) { + if (use_rope && hparams.use_kq_norm) { // Llama4TextL2Norm Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); @@ -4496,7 +4934,7 @@ struct llm_build_llama : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -4507,17 +4945,11 @@ struct llm_build_llama : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); // feed-forward network (non-MoE) if (model.layers[il].ffn_gate_inp == nullptr) { - cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); @@ -4530,9 +4962,7 @@ struct llm_build_llama : public llm_graph_context { NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); - - } else if (arch == LLM_ARCH_LLAMA4) { - // llama4 MoE + } else { ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); @@ -4561,31 +4991,6 @@ struct llm_build_llama : public llm_graph_context { cur = ggml_add(ctx0, moe_out, shexp_out); cb(cur, "ffn_moe_out_merged", il); - - } else { - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(cur, "ffn_moe_out", il); - } - - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -4610,11 +5015,6 @@ struct llm_build_llama : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -4644,6 +5044,7 @@ struct llm_build_deci : public llm_graph_context { ggml_tensor * inpSA = inpL; const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_head = hparams.n_head(il); + const int64_t n_ff = hparams.n_ff(il); if (n_head == 0) { // attention-free layer of Llama-3_1-Nemotron-51B @@ -4663,7 +5064,7 @@ struct llm_build_deci : public llm_graph_context { } else if (n_head > 0) { // self-attention // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -4709,7 +5110,7 @@ struct llm_build_deci : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -4719,9 +5120,9 @@ struct llm_build_deci : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + // FFN-free layer of Llama-3_1-Nemotron-Ultra-253B + if (n_ff == 0) { + continue; } // modified to support attention-free layer of Llama-3_1-Nemotron-51B @@ -4747,11 +5148,6 @@ struct llm_build_deci : public llm_graph_context { cb(cur, "ffn_out", il); } - // For Granite architecture - if (hparams.f_residual_scale) { - cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); - } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); @@ -4774,11 +5170,6 @@ struct llm_build_deci : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); - } - cb(cur, "result_output", -1); res->t_logits = cur; @@ -4851,7 +5242,7 @@ struct llm_build_baichuan : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -4966,7 +5357,7 @@ struct llm_build_xverse : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5091,7 +5482,7 @@ struct llm_build_falcon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5221,7 +5612,7 @@ struct llm_build_grok : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -5372,7 +5763,7 @@ struct llm_build_dbrx : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5486,7 +5877,7 @@ struct llm_build_starcoder : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5585,7 +5976,7 @@ struct llm_build_refact : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5661,8 +6052,10 @@ struct llm_build_bert : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); // token types are hardcoded to zero ("Sentence A") - ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); - inpL = ggml_add(ctx0, inpL, type_row0); + if (model.type_embd) { + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); + } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); } @@ -5683,43 +6076,44 @@ struct llm_build_bert : public llm_graph_context { ggml_tensor * Vcur; // self-attention - if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_JINA_BERT_V2) { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - - if (model.layers[il].attn_q_norm) { - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, - model.layers[il].attn_q_norm_b, - LLM_NORM, il); - } - - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - - if (model.layers[il].attn_k_norm) { - Kcur = build_norm(Kcur, - model.layers[il].attn_k_norm, - model.layers[il].attn_k_norm_b, - LLM_NORM, il); - } - - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - // compute Q and K and RoPE them + if (model.layers[il].wqkv) { cur = build_lora_mm(model.layers[il].wqkv, cur); cb(cur, "wqkv", il); + if (model.layers[il].bqkv) { + cur = ggml_add(ctx0, cur, model.layers[il].bqkv); + cb(cur, "bqkv", il); + } + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + } else { + Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); + Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); + Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); + } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, + model.layers[il].attn_q_norm_b, + LLM_NORM, il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, + model.layers[il].attn_k_norm_b, + LLM_NORM, il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -5739,7 +6133,7 @@ struct llm_build_bert : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { @@ -5764,20 +6158,37 @@ struct llm_build_bert : public llm_graph_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - if (model.arch == LLM_ARCH_BERT) { + if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, + model.layers[il].ffn_down_exps, + nullptr, + hparams.n_expert, + hparams.n_expert_used, + LLM_FFN_GELU, + false, false, + 0.0f, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(cur, "ffn_moe_out", il); + } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, - LLM_FFN_GELU, LLM_FFN_PAR, il); + model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); } else { cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, @@ -5785,8 +6196,8 @@ struct llm_build_bert : public llm_graph_context { model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); } - cb(cur, "ffn_out", il); // attentions bypass the intermediate layer cur = ggml_add(ctx0, cur, ffn_inp); @@ -5807,6 +6218,117 @@ struct llm_build_bert : public llm_graph_context { } }; +struct llm_build_neo_bert : public llm_graph_context { + llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); + + // construct input embeddings (token, type, position) + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + + // iterate layers + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + // pre-norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + + // self-attention + cur = build_lora_mm(model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd))); + Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd))); + Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa))); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + + if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // re-add the layer input + cur = ggml_add(ctx0, cur, inpL); + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + // pre-norm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + cur = build_ffn(cur, + model.layers[il].ffn_up, + NULL, NULL, NULL, NULL, NULL, + model.layers[il].ffn_down, + NULL, NULL, NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + + // attentions bypass the intermediate layer + cur = ggml_add(ctx0, cur, ffn_inp); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_bloom : public llm_graph_context { llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -5856,7 +6378,7 @@ struct llm_build_bloom : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -5997,7 +6519,7 @@ struct llm_build_mpt : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6143,7 +6665,7 @@ struct llm_build_stablelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6266,7 +6788,7 @@ struct llm_build_qwen : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6386,7 +6908,7 @@ struct llm_build_qwen2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6507,7 +7029,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6634,7 +7156,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6787,7 +7309,7 @@ struct llm_build_qwen3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -6908,7 +7430,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7048,7 +7570,7 @@ struct llm_build_phi2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7100,6 +7622,7 @@ struct llm_build_phi2 : public llm_graph_context { } }; +template struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -7115,7 +7638,14 @@ struct llm_build_phi3 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_unified_iswa(); + } else { + inp_attn = build_attn_inp_kv_unified(); + } for (int il = 0; il < n_layer; ++il) { auto * residual = inpL; @@ -7123,7 +7653,7 @@ struct llm_build_phi3 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); ggml_tensor* attn_norm_output = build_norm(inpL, model.layers[il].attn_norm, @@ -7177,7 +7707,7 @@ struct llm_build_phi3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -7312,7 +7842,7 @@ struct llm_build_plamo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } ggml_tensor * sa_out = cur; @@ -7419,7 +7949,7 @@ struct llm_build_gpt2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7535,7 +8065,7 @@ struct llm_build_codeshell : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7664,7 +8194,7 @@ struct llm_build_orion : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7791,7 +8321,7 @@ struct llm_build_internlm2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -7875,7 +8405,7 @@ struct llm_build_minicpm3 : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // norm cur = build_norm(inpL, @@ -7988,7 +8518,7 @@ struct llm_build_minicpm3 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -8118,7 +8648,7 @@ struct llm_build_gemma : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1) { @@ -8175,8 +8705,8 @@ struct llm_build_gemma : public llm_graph_context { } }; -struct llm_build_gemma2 : public llm_graph_context { - llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_gemma2_iswa : public llm_graph_context { + llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -8190,7 +8720,7 @@ struct llm_build_gemma2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv_unified_iswa(); for (int il = 0; il < n_layer; ++il) { // norm @@ -8229,18 +8759,11 @@ struct llm_build_gemma2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e - switch (model.type) { - case LLM_TYPE_2B: - case LLM_TYPE_9B: - case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); break; - default: GGML_ABORT("fatal error"); - }; - cb(Qcur, "Qcur_scaled", il); + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } cur = build_norm(cur, @@ -8312,8 +8835,8 @@ struct llm_build_gemma2 : public llm_graph_context { } }; -struct llm_build_gemma3 : public llm_graph_context { - llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_gemma3_iswa : public llm_graph_context { + llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -8331,13 +8854,11 @@ struct llm_build_gemma3 : public llm_graph_context { ggml_tensor * inp_pos = build_inp_pos(); // TODO: is causal == true correct? might need some changes - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv_unified_iswa(); for (int il = 0; il < n_layer; ++il) { - const bool is_swa = hparams.is_swa(il); - - const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; - const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); @@ -8379,9 +8900,12 @@ struct llm_build_gemma3 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } cur = build_norm(cur, @@ -8521,7 +9045,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8588,7 +9112,6 @@ struct llm_build_mamba : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -8597,8 +9120,7 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il); - cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il); + cur = build_mamba_layer(gf, cur, state_copy, ubatch, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -8639,12 +9161,11 @@ struct llm_build_mamba : public llm_graph_context { ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; @@ -8662,16 +9183,16 @@ struct llm_build_mamba : public llm_graph_context { GGML_ASSERT(ubatch.equal_seqs); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - ggml_tensor * conv_states_all = kv_self->k_l[il]; - ggml_tensor * ssm_states_all = kv_self->v_l[il]; + ggml_tensor * conv_states_all = kv_state->get_k_l(il); + ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states - ggml_tensor * conv = build_copy_mask_state( - gf, conv_states_all, state_copy, state_mask, + ggml_tensor * conv = build_recurrent_state( + gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_copy_mask_state( - gf, ssm_states_all, state_copy, state_mask, + ggml_tensor * ssm = build_recurrent_state( + gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); @@ -8856,7 +9377,7 @@ struct llm_build_command_r : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -8914,8 +9435,8 @@ struct llm_build_command_r : public llm_graph_context { } }; -struct llm_build_cohere2 : public llm_graph_context { - llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +struct llm_build_cohere2_iswa : public llm_graph_context { + llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8930,7 +9451,7 @@ struct llm_build_cohere2 : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_unified(); + auto * inp_attn = build_attn_inp_kv_unified_iswa(); for (int il = 0; il < n_layer; ++il) { const bool is_swa = hparams.is_swa(il); @@ -8943,7 +9464,7 @@ struct llm_build_cohere2 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -8991,7 +9512,7 @@ struct llm_build_cohere2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9122,7 +9643,7 @@ struct llm_build_olmo : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9242,7 +9763,7 @@ struct llm_build_olmo2 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } cur = build_norm(cur, @@ -9375,7 +9896,7 @@ struct llm_build_olmoe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9508,7 +10029,7 @@ struct llm_build_openelm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9622,7 +10143,7 @@ struct llm_build_gptneox : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9772,7 +10293,7 @@ struct llm_build_arctic : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -9881,7 +10402,7 @@ struct llm_build_deepseek : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -9927,7 +10448,7 @@ struct llm_build_deepseek : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, kq_scale, il); + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -10017,15 +10538,22 @@ struct llm_build_deepseek2 : public llm_graph_context { llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; + const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); - const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k)); - const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); - - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; - const uint32_t kv_lora_rank = hparams.n_lora_kv; + const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k)); + const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)); ggml_tensor * cur; ggml_tensor * inpL; @@ -10051,16 +10579,14 @@ struct llm_build_deepseek2 : public llm_graph_context { { ggml_tensor * q = NULL; if (!is_lite) { - // {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); q = build_norm(q, - model.layers[il].attn_q_a_norm, NULL, + model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); cb(q, "q", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); cb(q, "q", il); } else { @@ -10068,96 +10594,125 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(q, "q", il); } - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = ggml_view_3d(ctx0, q, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); cb(q_nope, "q_nope", il); - // and {n_head * n_embd_head_qk_rope, n_tokens} - ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d(ctx0, q, + n_embd_head_qk_rope, n_head, n_tokens, + ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); - // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens} - ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); - cb(kv_pe_compresseed, "kv_pe_compresseed", il); + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); // split into {kv_lora_rank, n_tokens} - ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens, - kv_pe_compresseed->nb[1], + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, + kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); - cb(kv_compressed, "kv_compressed", il); + cb(kv_cmpr, "kv_cmpr", il); - // and {n_embd_head_qk_rope, n_tokens} - ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens, - kv_pe_compresseed->nb[1], - kv_pe_compresseed->nb[1], - ggml_row_size(kv_pe_compresseed->type, kv_lora_rank)); + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, + n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); cb(k_pe, "k_pe", il); - // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont - kv_compressed = ggml_cont(ctx0, kv_compressed); - kv_compressed = build_norm(kv_compressed, - model.layers[il].attn_kv_a_norm, NULL, - LLM_NORM_RMS, il); - cb(kv_compressed, "kv_compressed", il); - - // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens} - ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed); - cb(kv, "kv", il); - - // split into {n_head * n_embd_head_qk_nope, n_tokens} - ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), - 0); - cb(k_nope, "k_nope", il); - - // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), - ggml_row_size(kv->type, (n_embd_head_qk_nope))); - cb(v_states, "v_states", il); - - v_states = ggml_cont(ctx0, v_states); - cb(v_states, "v_states", il); - - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), - 0); - cb(v_states, "v_states", il); - - q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - q_pe = ggml_rope_ext( - ctx0, q_pe, inp_pos, nullptr, + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); + ext_factor, attn_factor, beta_fast, beta_slow + ); cb(q_pe, "q_pe", il); - // shared RoPE key - k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this - k_pe = ggml_rope_ext( - ctx0, k_pe, inp_pos, nullptr, + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor_scaled, beta_fast, beta_slow - ); + ext_factor, attn_factor, beta_fast, beta_slow + ); cb(k_pe, "k_pe", il); - ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0); - cb(q_states, "q_states", il); + kv_cmpr = build_norm(kv_cmpr, + model.layers[il].attn_kv_a_norm, nullptr, + LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); - ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); - cb(k_states, "k_states", il); + if (is_mla) { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); - cur = build_attn(inp_attn, gf, - model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); + } else { + ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); + cb(kv, "kv", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, + n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + 0); + cb(k_nope, "k_nope_view", il); + + // and {n_embd_head_v, n_head, n_tokens} + ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, + n_embd_head_v, n_head, n_tokens, + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v), + ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, + ggml_row_size(kv->type, n_embd_head_qk_nope)); + cb(Vcur, "Vcur_view", il); + + Vcur = ggml_cont(ctx0, Vcur); + cb(Vcur, "Vcur_cont", il); + + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + cb(Kcur, "Kcur", il); + + // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) + cur = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } } if (il == n_layer - 1) { @@ -10323,7 +10878,7 @@ struct llm_build_bitnet : public llm_graph_context { cur = build_attn(inp_attn, gf, NULL, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, model.layers[il].attn_sub_norm, NULL, @@ -10446,7 +11001,7 @@ struct llm_build_t5_enc : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10552,7 +11107,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_self, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, kq_b, 1.0f, il); + Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -10584,7 +11139,7 @@ struct llm_build_t5_dec : public llm_graph_context { cur = build_attn(inp_attn_cross, gf, model.layers[il].wo_cross, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f, il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); @@ -10717,7 +11272,7 @@ struct llm_build_jais : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/float(n_embd_head), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1) { @@ -10849,7 +11404,7 @@ struct llm_build_chatglm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -10982,7 +11537,7 @@ struct llm_build_glm4 : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11126,7 +11681,7 @@ struct llm_build_nemotron : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11211,7 +11766,7 @@ struct llm_build_exaone : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -11257,7 +11812,7 @@ struct llm_build_exaone : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -11353,10 +11908,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11366,7 +11920,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { const auto n_head = n_embd / head_size; const auto n_head_kv = hparams.n_head_kv(il); - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -11477,8 +12031,8 @@ struct llm_build_rwkv6_base : public llm_graph_context { k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); } - ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_state->get_v_l(il), state_copy, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11497,9 +12051,9 @@ struct llm_build_rwkv6_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -11534,7 +12088,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11545,7 +12098,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); @@ -11561,7 +12114,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -11632,7 +12185,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11643,7 +12195,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); @@ -11656,7 +12208,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -11748,11 +12300,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor * state_copy, - ggml_tensor * state_mask, ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const auto * kv_state = static_cast(mstate); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11761,7 +12312,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { const auto head_count = n_embd / head_size; const auto n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = kv_self->head; + const auto kv_head = kv_state->get_head(); const auto & layer = model.layers[il]; @@ -11831,8 +12382,8 @@ struct llm_build_rwkv7_base : public llm_graph_context { v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); - ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_state->get_v_l(il), state_copy, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -11846,9 +12397,9 @@ struct llm_build_rwkv7_base : public llm_graph_context { wkv_state, ggml_view_1d( ctx0, - kv_self->v_l[il], + kv_state->get_v_l(il), hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self->v_l[il]) + hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il)) ) ) ); @@ -11890,7 +12441,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11901,7 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); @@ -11917,7 +12467,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -11984,7 +12534,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11995,7 +12544,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); @@ -12008,7 +12557,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -12061,6 +12610,194 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { } }; + +struct llm_build_granite : public llm_graph_context { + llm_build_granite( + const llama_model & model, + const llm_graph_params & params, + ggml_cgraph * gf, + const bool use_rope = true) + : llm_graph_context(params) { + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - built only if rope enabled + ggml_tensor * inp_pos = nullptr; + if (use_rope) { + inp_pos = build_inp_pos(); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and (optionally) RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + if (use_rope) { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } + + // For Granite architectures - scale residual + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + // For Granite architectures - scale logits + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm @@ -12159,7 +12896,7 @@ struct llm_build_chameleon : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, nullptr, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); if (hparams.swin_norm) { cur = build_norm(cur, @@ -12515,7 +13252,7 @@ struct llm_build_plm : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, - q_states, k_states, v_states, nullptr, kq_scale, il); + q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1) { @@ -12592,7 +13329,7 @@ struct llm_build_bailingmoe : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -12638,7 +13375,7 @@ struct llm_build_bailingmoe : public llm_graph_context { cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il); + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } if (il == n_layer - 1) { @@ -12712,36 +13449,356 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory() const { +struct llm_build_dots1 : public llm_graph_context { + llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_arcee : public llm_graph_context { + llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + // ARCEE uses relu^2 instead of silu + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; switch (arch) { + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_NEO_BERT: + case LLM_ARCH_WAVTOKENIZER_DEC: + { + res = nullptr; + } break; case LLM_ARCH_MAMBA: case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: { - res = new llama_kv_cache_unified(hparams, { - /*.get_rope_factors =*/ nullptr - }); + res = new llama_kv_cache_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max); } break; default: { - res = new llama_kv_cache_unified(hparams, { - /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) { - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } + const auto padding = llama_kv_cache_unified::get_padding(cparams); - if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - return layers[il].rope_short; - } - }); + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + res = new llama_kv_cache_unified_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.n_ctx, + cparams.n_seq_max, + cparams.n_ubatch, + padding); + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache_unified( + *this, + nullptr, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type); + } } } @@ -12756,13 +13813,13 @@ llm_graph_result_ptr llama_model::build_graph( switch (arch) { case LLM_ARCH_LLAMA: - case LLM_ARCH_LLAMA4: - case LLM_ARCH_MINICPM: - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_LLAMA4: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params, gf); @@ -12790,9 +13847,14 @@ llm_graph_result_ptr llama_model::build_graph( case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_NEO_BERT: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_BLOOM: { llm = std::make_unique(*this, params, gf); @@ -12836,7 +13898,11 @@ llm_graph_result_ptr llama_model::build_graph( case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: { - llm = std::make_unique(*this, params, gf); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + llm = std::make_unique> (*this, params, gf); + } else { + llm = std::make_unique>(*this, params, gf); + } } break; case LLM_ARCH_PLAMO: { @@ -12868,11 +13934,11 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_GEMMA2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_GEMMA3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_STARCODER2: { @@ -12892,7 +13958,7 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_COHERE2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_DBRX: { @@ -12989,6 +14055,12 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + case LLM_ARCH_MINICPM: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_CHAMELEON: { llm = std::make_unique(*this, params, gf); @@ -13005,6 +14077,14 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_DOTS1: + { + llm = std::make_unique(*this, params, gf); + } break; + case LLM_ARCH_ARCEE: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -13076,6 +14156,22 @@ int32_t llama_model_n_head_kv(const llama_model * model) { return model->hparams.n_head_kv(); } +int32_t llama_model_n_swa(const llama_model * model) { + return model->hparams.n_swa; +} + +uint32_t llama_model_n_cls_out(const struct llama_model * model) { + return model->hparams.n_cls_out; +} + +const char * llama_model_cls_label(const struct llama_model * model, uint32_t i) { + if (i < model->classifier_labels.size()) { + return model->classifier_labels[i].c_str(); + } + + return nullptr; +} + // deprecated int32_t llama_n_ctx_train(const llama_model * model) { return llama_model_n_ctx_train(model); @@ -13122,8 +14218,6 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: - case LLM_ARCH_PLAMO: - case LLM_ARCH_ORION: case LLM_ARCH_INTERNLM2: case LLM_ARCH_MINICPM: case LLM_ARCH_XVERSE: @@ -13140,6 +14234,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_CHAMELEON: case LLM_ARCH_BAILINGMOE: + case LLM_ARCH_NEO_BERT: + case LLM_ARCH_ARCEE: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -13148,6 +14244,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DBRX: case LLM_ARCH_BERT: case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: @@ -13160,6 +14257,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHI2: case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: + case LLM_ARCH_PLAMO: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: @@ -13167,9 +14265,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENELM: case LLM_ARCH_GPTNEOX: case LLM_ARCH_CODESHELL: + case LLM_ARCH_ORION: case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_MINICPM3: + case LLM_ARCH_DOTS1: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -13235,10 +14335,18 @@ uint64_t llama_model_size(const llama_model * model) { } const char * llama_model_chat_template(const llama_model * model, const char * name) { - const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE) : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { + // one-off fix for very popular models (so we are not flooded with issues) + // do not extend this list unless absolutely necessary + // Mistral-Small-2503 does not have built-in chat template + llama_vocab_pre_type pre_type = model->vocab.get_pre_type(); + if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) { + return "mistral-v7-tekken"; + } + return nullptr; } diff --git a/src/llama-model.h b/src/llama-model.h index 0f18dac16..06e6c6879 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -36,14 +36,17 @@ enum llm_type { LLM_TYPE_335M, LLM_TYPE_410M, LLM_TYPE_450M, + LLM_TYPE_475M, LLM_TYPE_770M, LLM_TYPE_780M, LLM_TYPE_0_5B, + LLM_TYPE_0_6B, LLM_TYPE_1B, LLM_TYPE_1_3B, LLM_TYPE_1_4B, LLM_TYPE_1_5B, LLM_TYPE_1_6B, + LLM_TYPE_1_7B, LLM_TYPE_1_8B, LLM_TYPE_2B, LLM_TYPE_2_8B, @@ -62,6 +65,7 @@ enum llm_type { LLM_TYPE_15B, LLM_TYPE_16B, LLM_TYPE_20B, + LLM_TYPE_27B, LLM_TYPE_30B, LLM_TYPE_32B, LLM_TYPE_34B, @@ -69,8 +73,11 @@ enum llm_type { LLM_TYPE_40B, LLM_TYPE_65B, LLM_TYPE_70B, + LLM_TYPE_142B, LLM_TYPE_236B, + LLM_TYPE_290B, LLM_TYPE_314B, + LLM_TYPE_405B, LLM_TYPE_671B, LLM_TYPE_SMALL, LLM_TYPE_MEDIUM, @@ -84,12 +91,14 @@ enum llm_type { LLM_TYPE_16x3_8B, LLM_TYPE_10B_128x3_66B, LLM_TYPE_57B_A14B, - LLM_TYPE_27B, - LLM_TYPE_290B, LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick + LLM_TYPE_30B_A3B, + LLM_TYPE_235B_A22B, }; +std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; @@ -171,6 +180,8 @@ struct llama_layer { struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; struct ggml_tensor * wkv_b = nullptr; + struct ggml_tensor * wk_b = nullptr; + struct ggml_tensor * wv_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; @@ -319,6 +330,9 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; + // for classifier models + std::vector classifier_labels; + struct ggml_tensor * tok_embd = nullptr; struct ggml_tensor * type_embd = nullptr; struct ggml_tensor * pos_embd = nullptr; @@ -388,8 +402,14 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; + float get_rope_freq_base (const llama_cparams & cparams, int il) const; + float get_rope_freq_scale(const llama_cparams & cparams, int il) const; + + ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const; + + // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory() const; // TODO: params + llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph( diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index e3e10fa6c..8cf45732f 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -10,9 +10,16 @@ #include #include #include +#include #include #include +// Quantization types. Changes to this struct must be replicated in quantize.cpp +struct tensor_quantization { + std::string name; + ggml_type quant = GGML_TYPE_COUNT; +}; + static void zeros(std::ofstream & file, size_t n) { char zero = 0; for (size_t i = 0; i < n; ++i) { @@ -48,7 +55,7 @@ struct quantize_state_impl { }; static void llama_tensor_dequantize_impl( - struct ggml_tensor * tensor, std::vector> & output, std::vector & workers, + ggml_tensor * tensor, std::vector> & output, std::vector & workers, const size_t nelements, const int nthread ) { if (output.size() < nelements) { @@ -512,7 +519,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } - // mmap consistently increases speed Linux, and also increases speed on Windows with + // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. #if defined(__linux__) || defined(_WIN32) constexpr bool use_mmap = true; @@ -522,7 +529,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: llama_model_kv_override * kv_overrides = nullptr; if (params->kv_overrides) { - auto v = (std::vector*)params->kv_overrides; + auto * v = (std::vector*)params->kv_overrides; kv_overrides = v->data(); } @@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: model.load_hparams(ml); model.load_stats (ml); - struct quantize_state_impl qs(model, params); + quantize_state_impl qs(model, params); if (params->only_copy) { ftype = ml.ftype; @@ -578,7 +585,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { - gguf_set_val_i32(ctx_out.get(), o.key, o.val_i64); + // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context + gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)abs(o.val_i64)); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool); } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { @@ -661,7 +669,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // populate the original tensors so we get an initial meta data for (const auto * it : tensors) { uint16_t i_split = params->keep_split ? it->idx : 0; - struct ggml_tensor * tensor = it->tensor; + ggml_tensor * tensor = it->tensor; if (!ctx_outs[i_split]) { ctx_outs[i_split].reset(gguf_init_empty()); } @@ -710,7 +718,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_ofstream(0); for (const auto * it : tensors) { const auto & weight = *it; - struct ggml_tensor * tensor = weight.tensor; + ggml_tensor * tensor = weight.tensor; if (weight.idx != cur_split && params->keep_split) { close_ofstream(); new_ofstream(weight.idx); @@ -776,7 +784,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // do not quantize relative position bias (T5) quantize &= name.find("attn_rel_b.weight") == std::string::npos; - enum ggml_type new_type; + ggml_type new_type; void * new_data; size_t new_size; @@ -786,7 +794,22 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. if (!params->pure && ggml_is_quantized(default_type)) { new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + // unless the user specifies a type + if (params->tensor_types) { + const std::vector & tensor_types = *static_cast *>(params->tensor_types); + const std::string tensor_name(tensor->name); + for (const auto & [tname, qtype] : tensor_types) { + if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + new_type = qtype; + break; // if two or more types are specified for the tensor, first match wins + } + } + } + } } + if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; } @@ -910,8 +933,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // interface implementation // -struct llama_model_quantize_params llama_model_quantize_default_params() { - struct llama_model_quantize_params result = { +llama_model_quantize_params llama_model_quantize_default_params() { + llama_model_quantize_params result = { /*.nthread =*/ 0, /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, /*.output_tensor_type =*/ GGML_TYPE_COUNT, @@ -923,6 +946,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { /*.keep_split =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, + /*.tensor_type =*/ nullptr, }; return result; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d14979850..bfbf5fa23 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -232,7 +232,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) // } if (k <= 0) { - k = cur_p->size; + return; } k = std::min(k, (int) cur_p->size); @@ -298,6 +298,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) } cur_p->sorted = true; } + cur_p->size = k; } @@ -797,7 +798,7 @@ static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_d } // if we have enough values the operation was a success - if (filtered_tokens.size() >= ctx->min_keep) { + if (!filtered_tokens.empty() && filtered_tokens.size() >= ctx->min_keep) { memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); cur_p->size = filtered_tokens.size(); min_p_applied = true; @@ -908,7 +909,7 @@ static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token cum_sum += cur_p->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens - if (cum_sum > ctx->p && i >= ctx->min_keep - 1) { + if (cum_sum > ctx->p && (ctx->min_keep == 0 || i >= ctx->min_keep - 1)) { last_idx = i + 1; break; } @@ -1749,23 +1750,35 @@ static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx; + if (ctx->n <= 0.0f || cur_p->size <= 1) { + return; + } + // find max logit and calculate mean float max = cur_p->data[0].logit; float logits_sum = 0; + size_t valid_count = 0; for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; + // Only count non-negative infinity values + if (cur_p->data[i].logit != -INFINITY) { + if (cur_p->data[i].logit > max) { + max = cur_p->data[i].logit; + } + logits_sum += cur_p->data[i].logit; + valid_count++; } - logits_sum += cur_p->data[i].logit; } - float mean = logits_sum/cur_p->size; + float mean = valid_count > 0 ? logits_sum/valid_count : 0; // calculate standard deviation float acc = 0; for (size_t i = 0; i < cur_p->size; ++i) { - acc += pow(cur_p->data[i].logit - mean, 2); + // Skip -infinity in std calculation + if (cur_p->data[i].logit != -INFINITY) { + acc += pow(cur_p->data[i].logit - mean, 2); + } } - float std = sqrt(acc/cur_p->size); + float std = valid_count > 0 ? sqrt(acc/valid_count) : 0; //apply mask for (size_t i = 0; i < cur_p->size; ++i) { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 464ff01e0..dd2251ef3 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1,5 +1,7 @@ #include "llama-vocab.h" +#include "ggml.h" +#include "gguf.h" #include "llama-impl.h" #include "llama-model-loader.h" @@ -7,16 +9,16 @@ #include #include +#include #include -#include #include #include #include +#include #include #include #include #include -#include // // helpers @@ -415,6 +417,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "'(?:[sSdDmMtT]|[lL][lL]|[vV][eE]|[rR][eE])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_SEED_CODER: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\r\n]+|\\s*[\r\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -826,7 +835,7 @@ struct llm_tokenizer_ugm_session { } // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores - std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX}); + std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX}); // at the beginning tokenization score is zero tokenization_results[0] = { vocab.token_unk(), 0, 0 }; @@ -858,7 +867,7 @@ struct llm_tokenizer_ugm_session { const double challenger_score = current_best.score_sum + token_score; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score }; + struct best_tokenization challenger = { token_id, input_offset, challenger_score }; current_champ = challenger; } } @@ -872,7 +881,7 @@ struct llm_tokenizer_ugm_session { prefix_offset = input_offset + n_utf8_code_units; struct best_tokenization & current_champ = tokenization_results[prefix_offset]; if (challenger_score > current_champ.score_sum) { - struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score }; + struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score }; current_champ = challenger; } } @@ -998,7 +1007,7 @@ private: struct best_tokenization { llama_token token_id; size_t input_offset; - float score_sum; + double score_sum; }; struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { @@ -1227,6 +1236,9 @@ struct fragment_buffer_variant { struct llama_vocab::impl { uint32_t n_token_types = 0; // for BERT-style token types + std::string tokenizer_model; + std::string tokenizer_pre; + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -1362,9 +1374,6 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // determine vocab type { - std::string tokenizer_model; - std::string tokenizer_pre; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); @@ -1459,7 +1468,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { - size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); + const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); + GGML_ASSERT(pc_type == GGUF_TYPE_INT8 || pc_type == GGUF_TYPE_UINT8); + + const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #ifdef IS_BIG_ENDIAN @@ -1506,7 +1518,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-bpe"|| - tokenizer_pre == "falcon3") { + tokenizer_pre == "falcon3" || + tokenizer_pre == "pixtral") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true; @@ -1633,6 +1646,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "bailingmoe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; + } else if ( + tokenizer_pre == "seed-coder") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -1841,6 +1858,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (false || t.first == "<|fim_prefix|>" // Qwen || t.first == "" + || t.first == "" // Granite || t.first == "<|fim▁begin|>" // DeepSeek || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
@@ -1859,6 +1877,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_suffix|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
@@ -1877,6 +1896,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_middle|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
@@ -1895,6 +1915,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_pad|>" // Qwen
                         || t.first == ""
+                        || t.first == ""   // Granite
                         || t.first == ""
                         ) {
                     special_fim_pad_id = t.second;
@@ -1913,6 +1934,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|repo_name|>"
                         || t.first == ""
                         || t.first == ""
+                        || t.first == ""    // Granite
                         ) {
                     special_fim_rep_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1965,6 +1987,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eom_id|>"
                     || t.first == ""
                     || t.first == "_"
+                    || t.first == "<|end_of_text|>"
                ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2058,9 +2081,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         std::string model_name;
         std::string tokenizer_pre;
+        std::string general_arch;
 
         ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
         ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+        ml.get_key(LLM_KV_GENERAL_ARCHITECTURE, general_arch, false);
 
         // model name to lowercase
         std::transform(model_name.begin(), model_name.end(), model_name.begin(),
@@ -2069,9 +2094,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             }
         );
 
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        // set attributes by model/tokenizer/architecture name
+        if (false
+                || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
+                || _contains_any(general_arch, {"nomic-bert-moe"})
+           ) {
+            if (token_to_id.count("") == 0) {
+                LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
+            } else {
+                _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : cache_special_tokens) {
                 _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
@@ -2541,6 +2573,10 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
     // copy piece chars to output text buffer
     // skip up to 'lstrip' leading spaces before copying
     auto _try_copy = [=] (const char * token, size_t size) -> int32_t {
+        if (size >= static_cast(std::numeric_limits::max())) {
+            GGML_ABORT("invalid token size: %zu exceeds int32_t limit", size);
+        }
+
         for (int32_t i = 0; i < lstrip && size && *token == ' '; ++i) {
             token++;
             size--;
@@ -2737,26 +2773,26 @@ void llama_vocab::impl::print_info() const {
     LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
 
     // special tokens
-    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token[special_bos_id].text.c_str() );  }
-    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token[special_eos_id].text.c_str() );  }
-    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token[special_eot_id].text.c_str() );  }
-    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token[special_eom_id].text.c_str() );  }
-    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token[special_unk_id].text.c_str() );  }
-    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token[special_sep_id].text.c_str() );  }
-    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token[special_pad_id].text.c_str() );  }
-    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token[special_mask_id].text.c_str() ); }
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token.at(special_bos_id).text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token.at(special_eos_id).text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token.at(special_eot_id).text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token.at(special_eom_id).text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token.at(special_unk_id).text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token.at(special_sep_id).text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token.at(special_pad_id).text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token.at(special_mask_id).text.c_str() ); }
 
-    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token[linefeed_id].text.c_str() ); }
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token.at(linefeed_id).text.c_str() ); }
 
-    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token[special_fim_pre_id].text.c_str() ); }
-    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token[special_fim_suf_id].text.c_str() ); }
-    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token[special_fim_mid_id].text.c_str() ); }
-    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token[special_fim_pad_id].text.c_str() ); }
-    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token[special_fim_rep_id].text.c_str() ); }
-    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token[special_fim_sep_id].text.c_str() ); }
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); }
 
     for (const auto & id : special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token[id].text.c_str() );
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() );
     }
 
     LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
@@ -2772,6 +2808,14 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
     pimpl->load(ml, kv);
 }
 
+std::string llama_vocab::get_tokenizer_model() const {
+    return pimpl->tokenizer_model;
+}
+
+std::string llama_vocab::get_tokenizer_pre() const {
+    return pimpl->tokenizer_pre;
+}
+
 enum llama_vocab_type llama_vocab::get_type() const {
     return pimpl->type;
 }
@@ -2994,6 +3038,20 @@ int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string
     return it->second;
 }
 
+std::vector llama_vocab::get_bpe_merges() const {
+    std::vector result(pimpl->bpe_ranks.size());
+
+    for (const auto & pair : pimpl->bpe_ranks) {
+        result[pair.second] = pair.first.first + " " + pair.first.second;
+    }
+
+    return result;
+}
+
+std::vector llama_vocab::get_precompiled_charsmap() const {
+    return pimpl->precompiled_charsmap;
+}
+
 int32_t llama_vocab::tokenize(
                   const char * text,
                      int32_t   text_len,
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 5ce355214..daa6cf308 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -21,6 +21,9 @@ struct llama_vocab {
 
     void load(llama_model_loader & ml, const LLM_KV & kv);
 
+    std::string get_tokenizer_model() const;
+    std::string get_tokenizer_pre() const;
+
     enum llama_vocab_type     get_type()     const;
     enum llama_vocab_pre_type get_pre_type() const;
 
@@ -80,6 +83,9 @@ struct llama_vocab {
     int max_token_len() const;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
+    std::vector get_bpe_merges() const;
+
+    std::vector get_precompiled_charsmap() const;
 
     int32_t tokenize(
                    const char * text,
diff --git a/src/llama.cpp b/src/llama.cpp
index d5164720b..34906cdb6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -4,6 +4,7 @@
 #include "llama-mmap.h"
 #include "llama-vocab.h"
 #include "llama-model-loader.h"
+#include "llama-model-saver.h"
 #include "llama-model.h"
 
 #include "ggml.h"
@@ -139,6 +140,11 @@ static struct llama_model * llama_model_load_from_file_impl(
         struct llama_model_params params) {
     ggml_time_init();
 
+    if (!params.vocab_only && ggml_backend_reg_count() == 0) {
+        LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__);
+        return nullptr;
+    }
+
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
         params.progress_callback_user_data = &cur_percentage;
@@ -192,14 +198,18 @@ static struct llama_model * llama_model_load_from_file_impl(
 
     // if using single GPU mode, remove all except the main GPU
     if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
-        if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
-            LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
-            llama_model_free(model);
-            return nullptr;
+        if (params.main_gpu < 0) {
+            model->devices.clear();
+        } else {
+            if (params.main_gpu >= (int)model->devices.size()) {
+                LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size());
+                llama_model_free(model);
+                return nullptr;
+            }
+            ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
+            model->devices.clear();
+            model->devices.push_back(main_gpu);
         }
-        ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
-        model->devices.clear();
-        model->devices.push_back(main_gpu);
     }
 
     for (auto * dev : model->devices) {
@@ -253,6 +263,13 @@ struct llama_model * llama_model_load_from_splits(
     return llama_model_load_from_file_impl(splits.front(), splits, params);
 }
 
+void llama_model_save_to_file(const struct llama_model * model, const char * path_model) {
+    llama_model_saver ms(*model);
+    ms.add_kv_from_model();
+    ms.add_tensors_from_model();
+    ms.save(path_model);
+}
+
 //
 // chat templates
 //
@@ -338,3 +355,4 @@ const char * llama_print_system_info(void) {
 
     return s.c_str();
 }
+
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 2bb210702..fc1557a2d 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,5 +1,17 @@
 llama_add_compile_flags()
 
+function(llama_build source)
+    if (DEFINED LLAMA_TEST_NAME)
+        set(TEST_TARGET ${LLAMA_TEST_NAME})
+    else()
+        get_filename_component(TEST_TARGET ${source} NAME_WE)
+    endif()
+
+    add_executable(${TEST_TARGET} ${source})
+    target_link_libraries(${TEST_TARGET} PRIVATE common)
+    install(TARGETS ${TEST_TARGET} RUNTIME)
+endfunction()
+
 function(llama_test target)
     include(CMakeParseArguments)
     set(options)
@@ -30,13 +42,41 @@ function(llama_test target)
     set_property(TEST ${TEST_NAME} PROPERTY LABELS ${LLAMA_TEST_LABEL})
 endfunction()
 
+function(llama_test_cmd target)
+    include(CMakeParseArguments)
+    set(options)
+    set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
+    set(multiValueArgs ARGS)
+    cmake_parse_arguments(LLAMA_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+
+    if (NOT DEFINED LLAMA_TEST_LABEL)
+        set(LLAMA_TEST_LABEL "main")
+    endif()
+    if (NOT DEFINED LLAMA_TEST_WORKING_DIRECTORY)
+        set(LLAMA_TEST_WORKING_DIRECTORY .)
+    endif()
+    if (DEFINED LLAMA_TEST_NAME)
+        set(TEST_NAME ${LLAMA_TEST_NAME})
+    else()
+        set(TEST_NAME ${target})
+    endif()
+
+    add_test(
+        NAME ${TEST_NAME}
+        WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
+        COMMAND ${target}
+        ${LLAMA_TEST_ARGS})
+
+    set_property(TEST ${TEST_NAME} PROPERTY LABELS ${LLAMA_TEST_LABEL})
+endfunction()
+
 # Builds and runs a test source file.
 # Optional args:
 # - NAME: name of the executable & test target (defaults to the source file name without extension)
 # - LABEL: label for the test (defaults to main)
 # - ARGS: arguments to pass to the test executable
 # - WORKING_DIRECTORY
-function(llama_target_and_test source)
+function(llama_build_and_test source)
     include(CMakeParseArguments)
     set(options)
     set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
@@ -58,6 +98,7 @@ function(llama_target_and_test source)
     add_executable(${TEST_TARGET} ${source} get-model.cpp)
     install(TARGETS ${TEST_TARGET} RUNTIME)
     target_link_libraries(${TEST_TARGET} PRIVATE common)
+
     add_test(
         NAME ${TEST_TARGET}
         WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
@@ -68,91 +109,108 @@ function(llama_target_and_test source)
 endfunction()
 
 # build test-tokenizer-0 target once and add many tests
-add_executable(test-tokenizer-0 test-tokenizer-0.cpp)
-target_link_libraries(test-tokenizer-0 PRIVATE common)
-install(TARGETS test-tokenizer-0 RUNTIME)
+llama_build(test-tokenizer-0.cpp)
 
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-bert-bge          ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bert-bge.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm      ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon            ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2             ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-mpt               ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-phi-3             ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-phi-3.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2             ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-qwen2.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact            ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
-llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
-
-if (LLAMA_LLGUIDANCE)
-    llama_target_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
-endif ()
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-bert-bge          ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-bert-bge.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-command-r.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-coder    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-coder.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-deepseek-llm      ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-deepseek-llm.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-falcon            ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-falcon.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-gpt-2             ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-2.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-bpe         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-llama-spm         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-mpt               ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-mpt.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-phi-3             ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-phi-3.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2             ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-qwen2.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact            ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-refact.gguf)
+llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder         ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-starcoder.gguf)
 
 if (NOT WIN32)
-    # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API
-    llama_target_and_test(test-sampling.cpp)
-    llama_target_and_test(test-grammar-parser.cpp)
-    llama_target_and_test(test-grammar-integration.cpp)
-    llama_target_and_test(test-llama-grammar.cpp)
-    llama_target_and_test(test-chat.cpp)
-    # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
-    if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
-        llama_target_and_test(test-json-schema-to-grammar.cpp   WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
-        target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
-    endif()
-
-
-    # build test-tokenizer-1-bpe target once and add many tests
-    add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
-    target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
-    install(TARGETS test-tokenizer-1-bpe RUNTIME)
-
-    # TODO: disabled due to slowness
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2     ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox  ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-neox.gguf)
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf --ignore-merges)
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt       ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
-    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
-
-    # build test-tokenizer-1-spm target once and add many tests
-    add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp)
-    target_link_libraries(test-tokenizer-1-spm PRIVATE common)
-    install(TARGETS test-tokenizer-1-spm RUNTIME)
-
-    llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
-    #llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-baichuan  ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
-
-    # llama_target_and_test(test-double-float.cpp) # SLOW
+    llama_test_cmd(
+        ${CMAKE_CURRENT_SOURCE_DIR}/test-tokenizers-repo.sh
+        NAME test-tokenizers-ggml-vocabs
+        WORKING_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}
+        ARGS https://huggingface.co/ggml-org/vocabs ${PROJECT_SOURCE_DIR}/models/ggml-vocabs
+    )
 endif()
 
-llama_target_and_test(test-log.cpp)
-llama_target_and_test(test-chat-template.cpp)
+if (LLAMA_LLGUIDANCE)
+    llama_build_and_test(test-grammar-llguidance.cpp ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf)
+endif ()
+
+if (NOT WIN32 OR NOT BUILD_SHARED_LIBS)
+    # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API (when building with shared libraries)
+    llama_build_and_test(test-sampling.cpp)
+    llama_build_and_test(test-grammar-parser.cpp)
+    llama_build_and_test(test-grammar-integration.cpp)
+    llama_build_and_test(test-llama-grammar.cpp)
+    llama_build_and_test(test-chat.cpp)
+    # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
+    if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
+        llama_build_and_test(test-json-schema-to-grammar.cpp   WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
+        target_include_directories(test-json-schema-to-grammar PRIVATE ${PROJECT_SOURCE_DIR}/tools/server)
+    endif()
+
+    if (NOT GGML_BACKEND_DL)
+        llama_build(test-quantize-stats.cpp)
+    endif()
+
+    llama_build(test-gbnf-validator.cpp)
+
+    # build test-tokenizer-1-bpe target once and add many tests
+    llama_build(test-tokenizer-1-bpe.cpp)
+
+    # TODO: disabled due to slowness
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-aquila.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-falcon    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-falcon.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-2     ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-2.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-gpt-neox  ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-gpt-neox.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-llama-bpe ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-bpe.gguf --ignore-merges)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-mpt       ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-mpt.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-refact    ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-refact.gguf)
+    #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-starcoder.gguf)
+
+    # build test-tokenizer-1-spm target once and add many tests
+    llama_build(test-tokenizer-1-spm.cpp)
+
+    llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-llama-spm ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf)
+    #llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-baichuan  ARGS ${PROJECT_SOURCE_DIR}/models/ggml-vocab-baichuan.gguf)
+
+    # llama_build_and_test(test-double-float.cpp) # SLOW
+endif()
+
+llama_build_and_test(test-chat-parser.cpp)
+llama_build_and_test(test-chat-template.cpp)
+llama_build_and_test(test-json-partial.cpp)
+llama_build_and_test(test-log.cpp)
+llama_build_and_test(test-regex-partial.cpp)
+
+llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
-    llama_target_and_test(test-arg-parser.cpp)
+    llama_build_and_test(test-arg-parser.cpp)
 endif()
 
-# llama_target_and_test(test-opt.cpp) # SLOW
-llama_target_and_test(test-gguf.cpp)
-llama_target_and_test(test-backend-ops.cpp)
+# llama_build_and_test(test-opt.cpp) # SLOW
+llama_build_and_test(test-gguf.cpp)
+llama_build_and_test(test-backend-ops.cpp)
 
-llama_target_and_test(test-model-load-cancel.cpp  LABEL "model")
-llama_target_and_test(test-autorelease.cpp        LABEL "model")
+llama_build_and_test(test-model-load-cancel.cpp  LABEL "model")
+llama_build_and_test(test-autorelease.cpp        LABEL "model")
 
 if (NOT GGML_BACKEND_DL)
     # these tests use the backends directly and cannot be built with dynamic loading
-    llama_target_and_test(test-barrier.cpp)
-    llama_target_and_test(test-quantize-fns.cpp)
-    llama_target_and_test(test-quantize-perf.cpp)
-    llama_target_and_test(test-rope.cpp)
+    llama_build_and_test(test-barrier.cpp)
+    llama_build_and_test(test-quantize-fns.cpp)
+    llama_build_and_test(test-quantize-perf.cpp)
+    llama_build_and_test(test-rope.cpp)
 endif()
 
+# libmtmd
+set(LLAMA_TEST_NAME test-mtmd-c-api)
+llama_build_and_test(test-mtmd-c-api.c)
+target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd)
 
 # dummy executable - not installed
 get_filename_component(TEST_TARGET test-c.c NAME_WE)
diff --git a/tests/run-json-schema-to-grammar.mjs b/tests/run-json-schema-to-grammar.mjs
index b20ac1d6b..450c3dde0 100644
--- a/tests/run-json-schema-to-grammar.mjs
+++ b/tests/run-json-schema-to-grammar.mjs
@@ -1,5 +1,5 @@
 import { readFileSync } from "fs"
-import { SchemaConverter } from "../examples/server/public_legacy/json-schema-to-grammar.mjs"
+import { SchemaConverter } from "../tools/server/public_legacy/json-schema-to-grammar.mjs"
 
 const [, , file] = process.argv
 const url = `file://${file}`
diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp
index 537fc63a4..e2836ca48 100644
--- a/tests/test-arg-parser.cpp
+++ b/tests/test-arg-parser.cpp
@@ -126,6 +126,53 @@ int main(void) {
     assert(params.cpuparams.n_threads == 1010);
 #endif // _WIN32
 
+    if (common_has_curl()) {
+        printf("test-arg-parser: test curl-related functions\n\n");
+        const char * GOOD_URL = "https://ggml.ai/";
+        const char * BAD_URL  = "https://www.google.com/404";
+        const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin";
+
+        {
+            printf("test-arg-parser: test good URL\n\n");
+            auto res = common_remote_get_content(GOOD_URL, {});
+            assert(res.first == 200);
+            assert(res.second.size() > 0);
+            std::string str(res.second.data(), res.second.size());
+            assert(str.find("llama.cpp") != std::string::npos);
+        }
+
+        {
+            printf("test-arg-parser: test bad URL\n\n");
+            auto res = common_remote_get_content(BAD_URL, {});
+            assert(res.first == 404);
+        }
+
+        {
+            printf("test-arg-parser: test max size error\n");
+            common_remote_params params;
+            params.max_size = 1;
+            try {
+                common_remote_get_content(GOOD_URL, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+
+        {
+            printf("test-arg-parser: test timeout error\n");
+            common_remote_params params;
+            params.timeout = 1;
+            try {
+                common_remote_get_content(BIG_FILE, params);
+                assert(false && "it should throw an error");
+            } catch (std::exception & e) {
+                printf("  expected error: %s\n\n", e.what());
+            }
+        }
+    } else {
+        printf("test-arg-parser: no curl, skipping curl-related functions\n");
+    }
 
     printf("test-arg-parser: all tests OK\n\n");
 }
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 3a5741c8d..509a4b35f 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -823,7 +823,7 @@ struct test_case {
 
         ggml_build_forward_expand(gf, out);
         ggml_graph_cpy(gf, gb);
-        ggml_build_backward_expand(ctx.get(), ctx.get(), gb, false);
+        ggml_build_backward_expand(ctx.get(), gb, nullptr);
         if (expect.size() != 1 || expect[0] != 0.0f) {
             GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf));
             for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
@@ -1026,7 +1026,7 @@ struct test_example : public test_case {
         // Step 3: return the output tensor.
         return out;
     }
-    // In order to also check the gradients for your op, add calls like ggml_set_param(ctx, a)
+    // In order to also check the gradients for your op, add calls like ggml_set_param(a)
     // immediately after you create the tensors.
     // This is optional and only makes sense if a backward pass has actually been implemented for the new op.
 };
@@ -1058,7 +1058,7 @@ struct test_unary : public test_case {
             auto ne = ne_a; ne[0] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
             if (grad_supported) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
 
@@ -1067,7 +1067,7 @@ struct test_unary : public test_case {
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
             if (grad_supported) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
         }
@@ -1133,7 +1133,7 @@ struct test_get_rows : public test_case {
 
         const bool grad_supported = ggml_is_matrix(in) && ggml_is_vector(rows);
         if (grad_supported) {
-            ggml_set_param(ctx, in);
+            ggml_set_param(in);
             // rows is a constant input -> no gradients
         }
 
@@ -1322,7 +1322,7 @@ struct test_repeat : public test_case {
         ggml_set_name(target, "target");
 
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         ggml_tensor * out = ggml_repeat(ctx, src, target);
@@ -1406,7 +1406,7 @@ struct test_dup : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         if (_use_permute) {
@@ -1442,7 +1442,7 @@ struct test_set : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         auto ne_dst = ne;
@@ -1450,7 +1450,7 @@ struct test_set : public test_case {
             ne_dst[i] *= 2;
         }
         ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data());
-        ggml_set_param(ctx, dst);
+        ggml_set_param(dst);
         ggml_set_name(dst, "dst");
 
         size_t offset = 0;
@@ -1498,7 +1498,7 @@ struct test_cpy : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         if (_src_use_permute) {
@@ -1536,7 +1536,7 @@ struct test_cont : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, src);
+        ggml_set_param(src);
         ggml_set_name(src, "src");
 
         src = ggml_transpose(ctx, src);
@@ -1583,8 +1583,8 @@ struct test_bin_bcast : public test_case {
         // The backward pass supports broadcasting only for GGML_ADD:
         const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
         if (grad_supported) {
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            ggml_set_param(a);
+            ggml_set_param(b);
         }
 
         ggml_tensor * out = op(ctx, a, b);
@@ -1632,11 +1632,11 @@ struct test_add1 : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * b = ggml_new_tensor_1d(ctx, type, 1);
-        // ggml_set_param(ctx, b); // TODO: implement
+        // ggml_set_param(b); // TODO: implement
         ggml_set_name(b, "b");
 
         ggml_tensor * out = ggml_add1(ctx, a, b);
@@ -1667,7 +1667,7 @@ struct test_scale : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_scale(ctx, a, scale);
@@ -1762,7 +1762,7 @@ struct test_rms_norm : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         if (v) {
@@ -1981,7 +1981,7 @@ struct test_mul_mat : public test_case {
     const std::array bs;  // dims 3 and 4
     const std::array nr;  // repeat in dims 3 and 4
     const std::array per; // permutation of dimensions
-    const bool v; // whether a is a non-contiguous view
+    const bool v; // whether a and b are non-contiguous views
 
     std::string vars() override {
         return VARS_TO_STR9(type_a, type_b, m, n, k, bs, nr, per, v);
@@ -2028,9 +2028,9 @@ struct test_mul_mat : public test_case {
             b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
             if (!ggml_is_quantized(type_a)) {
                 if (bs[1] == 1 && nr[1] == 1) {
-                    ggml_set_param(ctx, a);
+                    ggml_set_param(a);
                 }
-                ggml_set_param(ctx, b);
+                ggml_set_param(b);
             }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
@@ -2040,19 +2040,29 @@ struct test_mul_mat : public test_case {
             ggml_set_name(a, "a_permuted");
             ggml_set_name(b, "b_permuted");
         } else {
-
             if (v) {
-                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0], bs[1]);
-                a = ggml_view_4d(ctx, a, k, m, bs[0], bs[1], a->nb[1], a->nb[2], a->nb[3], 0);
+                a = ggml_new_tensor_4d(ctx, type_a, k*2, m, bs[0],       bs[1]);
+                b = ggml_new_tensor_4d(ctx, type_b, k*2, n, bs[0]*nr[0], bs[1]*nr[1]);
+
+                if (!ggml_is_quantized(type_a)) {
+                    if (bs[1] == 1 && nr[1] == 1) {
+                        ggml_set_param(a);
+                    }
+                    ggml_set_param(b);
+                }
+
+                a = ggml_view_4d(ctx, a, k, m, bs[0],       bs[1],       a->nb[1], a->nb[2], a->nb[3], 0);
+                b = ggml_view_4d(ctx, b, k, n, bs[0]*nr[0], bs[1]*nr[1], b->nb[1], b->nb[2], b->nb[3], 0);
             } else {
                 a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0],       bs[1]);
-            }
-            b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-            if (!ggml_is_quantized(type_a)) {
-                if (bs[1] == 1 && nr[1] == 1) {
-                    ggml_set_param(ctx, a);
+                b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
+
+                if (!ggml_is_quantized(type_a)) {
+                    if (bs[1] == 1 && nr[1] == 1) {
+                        ggml_set_param(a);
+                    }
+                    ggml_set_param(b);
                 }
-                ggml_set_param(ctx, b);
             }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
@@ -2071,7 +2081,7 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_b;
     const int n_mats;
     const int n_used;
-    const bool b; // brodcast b matrix
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t n;
     const int64_t k;
@@ -2201,7 +2211,7 @@ struct test_sqr : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sqr(ctx, a);
@@ -2230,7 +2240,7 @@ struct test_sqrt : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sqrt(ctx, a);
@@ -2270,7 +2280,7 @@ struct test_log : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_log(ctx, a);
@@ -2306,7 +2316,7 @@ struct test_sin : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sin(ctx, a);
@@ -2349,7 +2359,7 @@ struct test_cos : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_cos(ctx, a);
@@ -2429,7 +2439,7 @@ struct test_diag_mask_inf : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
@@ -2468,7 +2478,7 @@ struct test_soft_max : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * mask = nullptr;
@@ -2550,7 +2560,7 @@ struct test_rope : public test_case {
             auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
             a = ggml_new_tensor(ctx, type, 4, ne.data());
             if (forward) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
 
@@ -2559,7 +2569,7 @@ struct test_rope : public test_case {
         } else {
             a = ggml_new_tensor(ctx, type, 4, ne_a.data());
             if (forward) {
-                ggml_set_param(ctx, a);
+                ggml_set_param(a);
             }
             ggml_set_name(a, "a");
         }
@@ -2606,6 +2616,8 @@ struct test_rope : public test_case {
             } else {
                 out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
             }
+
+            // TODO: add test with a non-contiguous view as input ; this case is needed for build_rope_2d in clip.cpp
         }
         ggml_set_name(out, "out");
 
@@ -2671,7 +2683,7 @@ struct test_pool2d : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
-        ggml_set_param(ctx, input);
+        ggml_set_param(input);
         ggml_set_name(input, "input");
 
         ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
@@ -2694,8 +2706,8 @@ struct test_conv_transpose_1d : public test_case {
         return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
     }
 
-    test_conv_transpose_1d(std::array ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
-                           std::array ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
+    test_conv_transpose_1d(std::array ne_input = {197, 32, 1, 1}, // [input_width, input_channels, 1 /* assert in cpu kernel*/, 1 (should be batch)]
+                           std::array ne_kernel = {16, 32, 32, 1}, // [kernel_width, output_channels, input_channels, 1 (should be batch)]
                            int s0 = 1, int p0 = 0, int d0 = 1)
         : ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
 
@@ -2747,7 +2759,7 @@ struct test_im2col : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
-        ggml_set_param(ctx, input);
+        ggml_set_param(input);
         ggml_set_name(input, "input");
 
         ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
@@ -2760,6 +2772,48 @@ struct test_im2col : public test_case {
     }
 };
 
+// GGML_OP_CONV_2D_DW
+struct test_conv_2d_dw : public test_case {
+    const std::array ne_input;
+    const std::array ne_kernel;
+    const int stride;
+    const int padding;
+    const int dilation;
+    const bool cwhn;
+
+    std::string vars() override {
+        return VARS_TO_STR6(ne_input, ne_kernel, stride, padding, dilation, cwhn);
+    }
+
+    test_conv_2d_dw(std::array ne_input = {64, 64, 16, 1},
+            std::array ne_kernel = {3, 3, 1, 16},
+            int stride = 1, int padding = 0, int dilation = 1, bool cwhn = false)
+        : ne_input(ne_input), ne_kernel(ne_kernel), stride(stride), padding(padding), dilation(dilation), cwhn(cwhn) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
+        ggml_set_name(input, "input");
+
+        ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
+        ggml_set_name(kernel, "kernel");
+
+        if (cwhn) {
+            // change memory layout to channel-most-contiguous (CWHN),
+            // then permute it back so NE matches the original input
+            input = ggml_cont(ctx, ggml_permute(ctx, input, 1, 2, 0, 3));
+            input = ggml_permute(ctx, input, 2, 0, 1, 3);
+            kernel = ggml_cont(ctx, ggml_permute(ctx, kernel, 2, 3, 1, 0));
+            kernel = ggml_permute(ctx, kernel, 3, 2, 0, 1);
+        }
+
+        ggml_tensor * out = ggml_conv_2d_dw_direct(
+            ctx, kernel, input,
+            stride, stride, padding, padding, dilation, dilation);
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_CONCAT
 struct test_concat : public test_case {
     const ggml_type type;
@@ -2882,7 +2936,7 @@ struct test_sum : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sum(ctx, a);
@@ -2911,7 +2965,7 @@ struct test_sum_rows : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_sum_rows(ctx, a);
@@ -2936,7 +2990,7 @@ struct test_mean : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * out = ggml_mean(ctx, a);
@@ -3082,11 +3136,11 @@ struct test_acc : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
-        ggml_set_param(ctx, a);
+        ggml_set_param(a);
         ggml_set_name(a, "a");
 
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
-        ggml_set_param(ctx, b);
+        ggml_set_param(b);
         ggml_set_name(b, "b");
 
         ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
@@ -3323,7 +3377,7 @@ struct test_cross_entropy_loss : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * logits = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_param(ctx, logits);
+        ggml_set_param(logits);
         ggml_set_name(logits, "logits");
 
         ggml_tensor * labels = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -3405,7 +3459,7 @@ struct test_opt_step_adamw : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
-        ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not.
+        ggml_set_param(a); // Despite tensor a having gradients the output tensor will not.
         ggml_set_name(a, "a");
 
         ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]);
@@ -3970,6 +4024,23 @@ static std::vector> make_test_cases_eval() {
     // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
     // test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {1024, 1024, 256, 1}, {3, 3, 256, 1}, 1, 1, 1, 1, 1, 1, true));
 
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({17, 34, 9, 1}, {3, 3, 1, 9}, 1, 0, 1, true));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({32, 8, 64, 1}, {3, 3, 1, 64}, 2, 1, 1, true));
+
+    for(uint32_t Cout : {1, 9}){
+        for(uint32_t Cin : {1, 7}){
+            for(uint32_t K : {1, 3, 1337}){
+                for(uint32_t L : {1, 2, 13}){
+                    for(uint32_t s0: {1, 2, 3}){
+                        test_cases.emplace_back(new test_conv_transpose_1d({L,Cin,1,1}, {K,Cout,Cin,1}, s0, 0, 1));
+                    }
+                }
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_conv_transpose_1d());
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
@@ -4182,6 +4253,11 @@ static std::vector> make_test_cases_eval() {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
+
+            // test cases with large ne00/ne10 to cover stream-k fixup
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  8, 1024, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));
         }
     }
     for (ggml_type type_a : other_types) {
@@ -4428,10 +4504,11 @@ static std::vector> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int hsk : { 64, 80, 128, 192, 256, }) {
-        for (int hsv : { 64, 80, 128, 192, 256, }) {
-            if (hsk != 192 && hsk != hsv) continue;
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
 
             for (bool mask : { true, false } ) {
                 for (float max_bias : { 0.0f, 8.0f }) {
@@ -4532,10 +4609,15 @@ static std::vector> make_test_cases_perf() {
 
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
-            test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            for (int nr : { 1, 4, }) {
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+            }
         }
     }
 
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false));
+    test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true));
+
     return test_cases;
 }
 
diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp
new file mode 100644
index 000000000..59e44e07d
--- /dev/null
+++ b/tests/test-chat-parser.cpp
@@ -0,0 +1,352 @@
+//  Tests chat handling, including grammar generation and parsing for tool calling, for various templates.
+//
+//  Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates,
+//  e.g. given Minja (http://github.com/google/minja) checked out in parent dir:
+//
+//    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
+//
+#include 
+#include 
+#include 
+
+#include "chat-parser.h"
+#include "common.h"
+#include "log.h"
+#include "regex-partial.h"
+
+template 
+static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+static void assert_equals(const char * expected, const std::string & actual) {
+  return assert_equals(expected, actual);
+}
+
+static void assert_throws(const std::function & fn, const std::string & expected_exception_pattern = "") {
+    try {
+        fn();
+    } catch (const std::exception & e) {
+      if (expected_exception_pattern.empty()) {
+          return;
+        }
+        std::regex expected_exception_regex(expected_exception_pattern);
+        std::string actual_message = e.what();
+        if (std::regex_search(actual_message, expected_exception_regex)) {
+            return;
+        }
+        throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")");
+        throw std::runtime_error("Exception of unexpected type: " + std::string(e.what()));
+    }
+    throw std::runtime_error("Exception was expected but not thrown");
+}
+
+static void test_reasoning() {
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("", ""));
+    assert_equals("CogitoErgo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(true, builder.try_parse_reasoning("", ""));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ false,
+    });
+    assert_equals(false, builder.try_parse_reasoning("", ""));
+    assert_equals("CogitoErgo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ false,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("", ""));
+    assert_equals(std::string("Cogito"), builder.result().reasoning_content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+  {
+    common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, {
+        /* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
+        /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+        /* .reasoning_in_content = */ true,
+        /* .thinking_forced_open = */ true,
+    });
+    assert_equals(true, builder.try_parse_reasoning("", ""));
+    assert_equals("Cogito", builder.result().content);
+    assert_equals("Ergo sum", builder.consume_rest());
+  }
+}
+
+static void test_regex() {
+  auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") {
+    common_chat_msg_parser builder(input, /* is_partial= */ false, {});
+    assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern);
+  };
+
+  test_throws("Hello, world!", "abc", "^abc$");
+  test_throws("Hello, world!", "e", "^e$");
+
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    builder.consume_regex(common_regex("Hello"));
+    assert_equals(", world!", builder.consume_rest());
+  }
+
+  {
+    // When in non partial mode, we can say whether the regex was consumed or not.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value());
+  }
+  {
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {});
+    auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?"));
+    assert_equals(true, res.has_value());
+    // Verify captures
+    assert_equals(2, res->groups.size());
+    assert_equals("Hell", builder.str(res->groups[0]));
+    assert_equals("el", builder.str(res->groups[1]));
+    // Verify position is after the match
+    assert_equals(4, builder.pos());
+    assert_equals("o,", builder.consume_rest());
+  }
+  {
+    // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception.
+    common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {});
+    assert_throws([&]() {
+      builder.try_consume_regex(common_regex("Hello, world!"));
+    }, "^Hello, world!$");
+  }
+
+  // Now regardless of the mode, we can tell these aren't a match.
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value());
+  }
+  for (const auto is_partial : {false, true}) {
+    common_chat_msg_parser builder("Hello,", is_partial, {});
+    assert_equals(false, builder.try_consume_literal("Oh"));
+  }
+}
+
+const std::vector barely_healable_jsons = {
+  "{",
+  "{\"",
+  "{\"\\",
+  "{\"n",
+  "{\"name\"",
+  "{\"name\":",
+  "{\"name\":\"",
+  "{\"name\":\"\\",
+  "{\"name\":\"python",
+  "{\"name\":\"python\\",
+  "{\",",
+  "{\":",
+  "{\"[",
+  "{\"]",
+  "{\"{",
+  "{\"}",
+  "{\"1",
+  "{\"name\":\",",
+  "{\"name\":\":",
+  "{\"name\":\"[",
+  "{\"name\":\"]",
+  "{\"name\":\"{",
+  "{\"name\":\"}",
+  "{\"name\":\"1",
+};
+
+static void test(const std::string & input, bool is_partial, const std::vector> & args_paths, const std::vector> & content_paths, const std::string & expected) {
+  common_chat_msg_parser builder(input, is_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths);
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get() : js->value.dump());
+}
+static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) {
+  common_chat_msg_parser builder(input, parse_as_partial, {});
+  auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {});
+  assert_equals(true, js.has_value());
+  assert_equals(is_partial, js->is_partial);
+  assert_equals(expected, js->value.dump());
+}
+
+static void test_json_with_dumped_args_no_args() {
+  // Normal JSON, nothing to heal, nothing to dump
+  test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}");
+  // Full json is args
+  test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}");
+
+  // If the arguments are further down, don't heal partial content.
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{"arguments"}}, {}, "{}");
+  }
+  // But heal content that isn't partial.
+  test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}");
+}
+
+static void test_json_with_dumped_args() {
+
+  // Partial content.
+  test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}");
+  test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}");
+  test("{\"content\": ", true, {}, {{"content"}}, "{}");
+
+  // If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted).
+  test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python");
+  for (const auto & src : barely_healable_jsons) {
+    test(src, true, {{}}, {}, src);
+  }
+
+  // Full JSON w/ args
+  for (auto parse_as_partial : {true, false}) {
+    test_with_args(
+      R"({"name": "python", "args": {"arg1": 1}})",
+      R"({"name":"python","args":"{\"arg1\":1}"})",
+      parse_as_partial,
+      /* is_partial= */ false
+    );
+  }
+
+  // Partial JSON w/ partial args
+  test_with_args(
+    R"({"foo": "bar", "args": {")",
+    R"({"foo":"bar","args":"{\""})"
+  );
+  // Partial args broken in object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"ar)",
+    R"({"foo":"bar","args":"{\"ar"})"
+  );
+  // Partial args broken after object key
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1")",
+    R"({"foo":"bar","args":"{\"arg1\""})"
+  );
+  // Partial args broken before object value
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1":)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken before object value (space)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": )",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that may not be complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1)",
+    R"({"foo":"bar","args":"{\"arg1\":"})"
+  );
+  // Partial args broken in object value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": 1 )",
+    R"({"foo":"bar","args":"{\"arg1\":1"})"
+  );
+  // Partial args broken in object value that is incomplete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": ")",
+    R"({"foo":"bar","args":"{\"arg1\":\""})"
+  );
+  // Partial args broken in object value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": "1")",
+    R"({"foo":"bar","args":"{\"arg1\":\"1\""})"
+  );
+  // Partial args broken on array opening
+  test_with_args(
+    R"({"foo": "bar", "args": [)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is incomplete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1)",
+    R"({"foo":"bar","args":"["})"
+  );
+  // Partial args broken on array value that is complete (int)
+  test_with_args(
+    R"({"foo": "bar", "args": [1 )",
+    R"({"foo":"bar","args":"[1"})"
+  );
+  // Partial args broken on array value that is complete (string)
+  test_with_args(
+    R"({"foo": "bar", "args": ["1")",
+    R"({"foo":"bar","args":"[\"1\""})"
+  );
+  // Partial args broken after array value
+  test_with_args(
+    R"({"foo": "bar", "args": [1,)",
+    R"({"foo":"bar","args":"[1,"})"
+  );
+  // Partial args broken on nested array
+  test_with_args(
+    R"({"foo": "bar", "args": {"arg1": [)",
+    R"({"foo":"bar","args":"{\"arg1\":["})"
+  );
+}
+
+static void test_positions() {
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {});
+    assert_equals(0, builder.pos());
+    assert_throws([&]() { builder.move_to(100); });
+    assert_equals(0, builder.pos());
+    assert_throws([&]() { builder.move_back(1); });
+    assert_equals(0, builder.pos());
+
+    builder.move_to(8);
+    assert_equals(8, builder.pos());
+    builder.move_back(1);
+    assert_equals(7, builder.pos());
+    assert_equals("world!", builder.consume_rest());
+
+    builder.move_to(0);
+    assert_equals(0, builder.pos());
+
+    assert_throws([&]() { builder.finish(); });
+    assert_equals(0, builder.pos());
+
+    builder.move_to(builder.input().size());
+    builder.finish();
+  }
+  {
+    common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {});
+
+    builder.move_to(builder.input().size());
+    assert_equals(builder.input().size(), builder.pos());
+    builder.finish();
+  }
+}
+
+int main() {
+    test_positions();
+    test_json_with_dumped_args_no_args();
+    test_json_with_dumped_args();
+    test_reasoning();
+    test_regex();
+    std::cout << "All tests passed!\n";
+    return 0;
+}
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index be1a64006..a0a50f988 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -181,8 +181,8 @@ int main(void) {
         },
         {
             /* .name= */ "ChatGLM4",
-            /* .template_str= */ U8C("[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"),
-            /* .expected_output= */ "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>",
+            /* .template_str= */ U8C("[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}"),
+            /* .expected_output= */ "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n   I am an assistant   <|user|>\nAnother question<|assistant|>\n",
             /* .expected_output_jinja= */ "",
             /* .bos_token= */ "",
             /* .eos_token= */ "",
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index a0bf6affe..6ebf1464d 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -5,20 +5,82 @@
 //
 //    cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null
 //
+#include "chat.h"
+
+#include "log.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
+#include 
+
 #include 
 #include 
-#include 
 #include 
 
-#include "chat.h"
-#include "llama-grammar.h"
-#include "unicode.h"
-
 using json = nlohmann::ordered_json;
 
+static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & diff) {
+    os << "{ content_delta: " << diff.content_delta << "; ";
+    os << "reasoning_content_delta: " << diff.reasoning_content_delta << "; ";
+    if (diff.tool_call_index != std::string::npos) {
+        os << "tool_call_index: " << diff.tool_call_index << "; ";
+        os << "tool_call_delta.name: " << diff.tool_call_delta.name << "; ";
+        os << "tool_call_delta.id: " << diff.tool_call_delta.id << "; ";
+        os << "tool_call_delta.arguments: " << diff.tool_call_delta.arguments << "; ";
+    }
+    os << "}";
+    return os;
+}
+// operator<< for vector:
+static std::ostream & operator<<(std::ostream & os, const std::vector & diffs) {
+    os << "[\n";
+    for (const auto & diff : diffs) {
+        os << "  " << diff << ",\n";
+    }
+    os << "]";
+    return os;
+}
+static std::ostream & operator<<(std::ostream & os, const common_chat_msg & msg) {
+    os << "{ role: " << msg.role << "; ";
+    os << "content: " << msg.content << "; ";
+    os << "content_parts: [\n";
+    for (const auto & part : msg.content_parts) {
+        os << "  { type: " << part.type << "; text: " << part.text << " },\n";
+    }
+    os << "]; ";
+    os << "reasoning_content: " << msg.reasoning_content << "; ";
+    os << "tool_calls: [\n";
+    for (const auto & tool_call : msg.tool_calls) {
+        os << "  { name: " << tool_call.name << "; arguments: " << tool_call.arguments << "; id: " << tool_call.id << " },\n";
+    }
+    os << "]";
+    os << "}";
+    return os;
+}
+
+template  static bool equals(const T & expected, const T & actual) {
+    return expected == actual;
+}
+
+static common_chat_msg normalize(const common_chat_msg & msg) {
+    common_chat_msg normalized = msg;
+    for (auto & tool_call : normalized.tool_calls) {
+        try {
+            tool_call.arguments = json::parse(tool_call.arguments).dump();
+        } catch (const std::exception &) {
+            // Do nothing
+        }
+    }
+    return normalized;
+}
+template <>
+bool equals(const common_chat_msg & expected, const common_chat_msg & actual) {
+    return normalize(expected) == normalize(actual);
+}
 
 template  static void assert_equals(const T & expected, const T & actual) {
-    if (expected != actual) {
+    if (!equals(expected, actual)) {
         std::cerr << "Expected: " << expected << std::endl;
         std::cerr << "Actual: " << actual << std::endl;
         std::cerr << std::flush;
@@ -76,6 +138,15 @@ static bool match_string(const std::string & input, llama_grammar * grammar) {
     return false;
 }
 
+static std::string renormalize_json(const std::string & json_str) {
+    try {
+        auto json_obj = json::parse(json_str);
+        return json_obj.dump();
+    } catch (const std::exception & e) {
+        std::cerr << "Failed to parse JSON: " << e.what() << '\n';
+        return json_str;
+    }
+}
 static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
     assert_equals(expected.role, actual.role);
     assert_equals(expected.content, actual.content);
@@ -92,7 +163,7 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha
         const auto & expected_tool_call = expected.tool_calls[i];
         const auto & actual_tool_call   = actual.tool_calls[i];
         assert_equals(expected_tool_call.name, actual_tool_call.name);
-        assert_equals(json::parse(expected_tool_call.arguments).dump(), json::parse(actual_tool_call.arguments).dump());
+        assert_equals(renormalize_json(expected_tool_call.arguments), renormalize_json(actual_tool_call.arguments));
         assert_equals(expected_tool_call.id, actual_tool_call.id);
     }
 }
@@ -151,14 +222,12 @@ static delta_data init_delta(const struct common_chat_templates * tmpls, const s
                              const common_chat_msg & user_message,
                              const common_chat_msg & delta_message,
                              const std::vector & tools,
-                             const common_chat_tool_choice & tool_choice,
-                             bool think = false) {
+                             const common_chat_tool_choice & tool_choice) {
     common_chat_templates_inputs inputs;
     inputs.parallel_tool_calls = true;
     inputs.messages.push_back(user_message);
     inputs.tools       = tools;
     inputs.tool_choice = tool_choice;
-    inputs.extract_reasoning = think;
     auto params_prefix = common_chat_templates_apply(tmpls, inputs);
 
     inputs.messages.push_back(delta_message);
@@ -210,19 +279,22 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
                           const std::string & expected_delta = "",
                           bool expect_grammar_triggered = true,
                           bool test_grammar_if_triggered = true,
-                          bool think = false) {
+                          common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE) {
     common_chat_msg user_message;
     user_message.role = "user";
     user_message.content = "Hello, world!";
 
     for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
-        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice, think);
+        auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
         if (!expected_delta.empty()) {
             assert_equals(expected_delta, data.delta);
         }
 
         if (expect_grammar_triggered) {
-            const auto msg = common_chat_parse(data.delta, data.params.format);
+            common_chat_syntax syntax;
+            syntax.format = data.params.format;
+            syntax.reasoning_format = reasoning_format;
+            const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax);
             assert_msg_equals(test_message, msg);
         }
 
@@ -250,15 +322,25 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
                     {
                         const auto & pattern = trigger.value;
                         if (std::regex_search(constrained, match, std::regex(pattern))) {
-                            pos = match.position();
+                            pos = match.position(1);
                         }
                         break;
                     }
-                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
+                    case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                     {
                         const auto & pattern = trigger.value;
-                        if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) {
-                            pos = 0;
+                        if (std::regex_match(constrained, match, std::regex(pattern))) {
+                            auto mpos = std::string::npos;
+                            for (size_t i = 1; i < match.size(); ++i) {
+                                if (match[i].length() > 0) {
+                                    mpos = match.position(i);
+                                    break;
+                                }
+                            }
+                            if (mpos == std::string::npos) {
+                                mpos = match.position(0);
+                            }
+                            pos = mpos;
                         }
                         break;
                     }
@@ -312,117 +394,42 @@ const common_chat_msg message_user_parts {
     /* .tool_name = */ "",
     /* .tool_call_id = */ "",
 };
-const common_chat_msg message_assist {
-    "assistant",
-    "Hello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_thoughts_unparsed_think {
-    "assistant",
-    "I'm thinkingHello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_thoughts_unparsed_r7b {
-    "assistant",
-    "<|START_THINKING|>I'm thinking<|END_THINKING|>Hello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_thoughts {
-    "assistant",
-    "Hello, world!\nWhat's up?",
-    /* .content_parts = */ {},
-    /* .tool_calls = */ {},
-    /* .reasoning_content = */ "I'm thinking",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const std::vector tool_calls {
-    { "special_function", "{\"arg1\": 1}", /* .id = */ "" },
-};
-const std::vector tool_calls_idx {
-    { "special_function", "{\"arg1\": 1}", /* .id = */ "0" },
-};
-const std::vector tool_calls_id {
-    { "special_function", "{\"arg1\": 1}", /* .id = */ "123456789" },
-};
+static common_chat_msg simple_assist_msg(const std::string & content, const std::string & reasoning_content = "", const std::string & tool_name = "", const std::string & arguments = "", const std::string & id = "") {
+    common_chat_msg msg;
+    msg.role = "assistant";
+    msg.content = content;
+    msg.reasoning_content = reasoning_content;
+    if (!tool_name.empty()) {
+        msg.tool_calls.push_back({ tool_name, arguments, id });
+    }
+    return msg;
+}
+const common_chat_msg message_assist                              = simple_assist_msg("Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_empty                        = simple_assist_msg("");
+const common_chat_msg message_assist_thoughts_unparsed_deepseek   = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_unparsed_md         = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```");
+const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}");
 
-const common_chat_msg message_assist_call {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    tool_calls,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_thoughts = {
-    "assistant",
-    /* .content = */ "",
-    /* .content_parts = */ {},
-    tool_calls,
-    /* .reasoning_content = */ "I'm\nthinking",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_thoughts_unparsed = {
-    "assistant",
-    /* .content = */ "I'm\nthinking",
-    /* .content_parts = */ {},
-    tool_calls,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_id {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    tool_calls_id,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_idx {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    tool_calls_idx,
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_python {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    { { "python", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
-const common_chat_msg message_assist_call_code_interpreter {
-    "assistant",
-    "",
-    /* .content_parts = */ {},
-    { { "code_interpreter", "{\"code\": \"print('hey')\"}", /* .id = */ "" } },
-    /* .reasoning_content = */ "",
-    /* .tool_name = */ "",
-    /* .tool_call_id = */ "",
-};
+const common_chat_msg message_assist_thoughts_unparsed_r7b       = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts                    = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking");
+const common_chat_msg message_assist_thoughts_unopened_unparsed  = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_no_content         = simple_assist_msg("", "I'm\nthinking");
+const common_chat_msg message_assist_call                        = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}");
+const common_chat_msg message_assist_call_content                = simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}");
+const common_chat_msg message_assist_call_empty_args             = simple_assist_msg("", "", "special_function");
+const common_chat_msg message_assist_call_cutoff_args            = simple_assist_msg("", "", "special_function", "{\"arg");
+const common_chat_msg message_assist_call_thoughts               = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}");
+const common_chat_msg message_assist_call_thoughts_unparsed      = simple_assist_msg("I'm\nthinking\n\n", "", "special_function", "{\"arg1\": 1}");
+const common_chat_msg message_assist_call_id                     = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789");
+const common_chat_msg message_assist_call_idx                    = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0");
+const common_chat_msg message_assist_thoughts_call_idx           = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0");
+const common_chat_msg message_assist_call_python                 = simple_assist_msg("", "", "python", "{\"code\":\"print('hey')\"}");
+const common_chat_msg message_assist_call_python_lines           = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')\"}");
+const common_chat_msg message_assist_call_python_lines_unclosed  = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')");
+const common_chat_msg message_assist_call_code_interpreter       = simple_assist_msg("", "", "code_interpreter", "{\"code\":\"print('hey')\"}");
 
 static void test_msgs_oaicompat_json_conversion() {
+    printf("[%s]\n", __func__);
     std::vector msgs{
         message_user,
         message_user_parts,
@@ -472,7 +479,7 @@ static void test_msgs_oaicompat_json_conversion() {
             "        \"type\": \"function\",\n"
             "        \"function\": {\n"
             "          \"name\": \"python\",\n"
-            "          \"arguments\": \"{\\\"code\\\": \\\"print('hey')\\\"}\"\n"
+            "          \"arguments\": \"{\\\"code\\\":\\\"print('hey')\\\"}\"\n"
             "        }\n"
             "      }\n"
             "    ]\n"
@@ -498,6 +505,7 @@ static void test_msgs_oaicompat_json_conversion() {
 }
 
 static void test_tools_oaicompat_json_conversion() {
+    printf("[%s]\n", __func__);
     std::vector tools{
         special_function_tool,
         python_tool,
@@ -542,29 +550,18 @@ static void test_tools_oaicompat_json_conversion() {
 }
 
 static void test_template_output_parsers() {
+    printf("[%s]\n", __func__);
 
     common_chat_templates_inputs inputs_no_tools;
     inputs_no_tools.messages                = {message_user};
-    inputs_no_tools.extract_reasoning       = false;
-
-    common_chat_templates_inputs inputs_no_tools_think;
-    inputs_no_tools_think.messages          = {message_user};
-    inputs_no_tools_think.extract_reasoning = true;
 
     common_chat_templates_inputs inputs_tools;
     inputs_tools.messages                   = {message_user};
     inputs_tools.tools                      = {special_function_tool};
-    inputs_tools.extract_reasoning          = false;
-
-    common_chat_templates_inputs inputs_tools_think;
-    inputs_tools_think.messages             = {message_user};
-    inputs_tools_think.tools                = {special_function_tool};
-    inputs_tools_think.extract_reasoning    = true;
 
     common_chat_templates_inputs inputs_tools_builtin;
     inputs_tools_builtin.messages           = {message_user};
     inputs_tools_builtin.tools              = {python_tool};
-    inputs_tools_builtin.extract_reasoning  = false;
 
     {
         // Not supported yet
@@ -576,44 +573,87 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja");
         std::vector   end_tokens{ "<|END_OF_TURN_TOKEN|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, params.format);
+            assert_equals(false, params.thinking_forced_open);
+        }
 
         assert_msg_equals(message_assist,
             common_chat_parse(
                 "Hello, world!\nWhat's up?",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(message_assist,
-            common_chat_parse(
-                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
         assert_msg_equals(message_assist,
             common_chat_parse(
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
-            common_chat_parse(
-                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
-                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
-        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
-            common_chat_parse(
-                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
-                "Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B));
-
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
         assert_msg_equals(message_assist_thoughts,
             common_chat_parse(
-                "<|START_THINKING|>I'm thinking<|END_THINKING|>"
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
                 "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
-                COMMON_CHAT_FORMAT_COMMAND_R7B_EXTRACT_REASONING));
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ true,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_r7b,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_COMMAND_R7B}));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_call_idx,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_ACTION|>[\n"
+                "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
+                "]<|END_ACTION|>",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_no_content,
+            common_chat_parse(
+                "<|START_THINKING|>I'm\nthinking<|END_THINKING|>"
+                "<|START_ACTION|>[\n"
+                "    {\"tool_call_id\": \"0\", \"tool_name\": \"special",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
 
         test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools,
                       "<|START_THINKING|><|END_THINKING|>"
                       "<|START_ACTION|>[\n"
                       "    {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n"
-                      "]<|END_ACTION|>");
+                      "]<|END_ACTION|>",
+                      /* expect_grammar_triggered= */ true,
+                      /* test_grammar_if_triggered= */ true,
+                      COMMON_REASONING_FORMAT_DEEPSEEK);
         test_templates(tmpls.get(), end_tokens, message_assist, tools,
                       "<|START_RESPONSE|>Hello, world!\n"
                       "What's up?<|END_RESPONSE|>",
@@ -633,11 +673,52 @@ static void test_template_output_parsers() {
 
         // Generic tool calls doesn't generate / parse content-only messages symmetrically.
 
+        assert_equals(
+            simple_assist_msg("{ \"tool_call\" : { \"name\" : \"t"),
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"t",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_GENERIC,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                    /* .parse_tool_calls = */ false,
+                }));
+        assert_equals(
+            message_assist_empty,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"t",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
+        assert_equals(
+            simple_assist_msg("", "", "puppeteer_screenshot", "{\"name\":\"servethehome_homepage\","),
+            common_chat_parse(
+                R"({"tool_call": {"name": "puppeteer_screenshot", "arguments": {"name": "servethehome_homepage",)",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
+        assert_equals(
+            message_assist_call_empty_args,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"special_function\"",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+        assert_equals(
+            message_assist_call_cutoff_args,
+            common_chat_parse(
+                "{ \"tool_call\" : { \"name\" : \"special_function\", \"arguments\" : { \"arg",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_GENERIC}));
+
         assert_msg_equals(message_assist,
-                          common_chat_parse("{\n"
-                                            "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
-                                            "}",
-                                            common_chat_templates_apply(tmpls.get(), inputs_tools).format));
+            common_chat_parse(
+                "{\n"
+                "  \"response\": \"Hello, world!\\nWhat's up?\"\n"
+                "}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_GENERIC}));
         test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools,
                       "{\n"
                       "  \"tool_calls\": [\n"
@@ -662,11 +743,18 @@ static void test_template_output_parsers() {
             tmpls.get(), end_tokens, message_assist_call_id, tools,
             "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
     }
+    {
+        auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
+        std::vector end_tokens{ "<|im_end|>" };
+
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+    }
     {
         auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja");
         std::vector end_tokens{ "<|im_end|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
         assert_equals(
             COMMON_CHAT_FORMAT_HERMES_2_PRO,
@@ -682,114 +770,288 @@ static void test_template_output_parsers() {
                 .format);
 
         // Test parsing
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "{\"arg1\": 1}",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "{\"arg1\": 1}\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```xml\n"
-            "\n"
-            "    {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```xml\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```\n"
-            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```json\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "```",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "```json\n"
-            "\n"
-            "                     {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
-            "                     \n"
-            "``` ",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "  {\n"
-            "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
-            "  }\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "\n"
-            "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
-            "",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_call, common_chat_parse(
-            "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(
+            simple_assist_msg("", "", "python", ""),
+            common_chat_parse(
+                "```json\n"
+                " { \"name\" : \"python\"",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            simple_assist_msg("Let's call something\n"),
+            common_chat_parse(
+                "Let's call something\n"
+                "{\"name\"",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(
+            simple_assist_msg("Let's call something\n"),
+            common_chat_parse(
+                "Let's call something\n"
+                "{\"name",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_call_thoughts,
+            common_chat_parse(
+                // QwQ-32B's template adds a trailing  if add_generation_prompt
+                "I'm\nthinking\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist_call_content,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"arg1\": 1}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "{\"arg1\": 1}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```xml\n"
+                "\n"
+                "    {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```xml\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```json\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "```",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "```json\n"
+                "\n"
+                "                     {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n"
+                "                     \n"
+                "``` ",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\n"
+                "    \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n"
+                "  }\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "\n"
+                "  {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
+                "",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\n  \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
 
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO));
+        assert_msg_equals(
+            simple_assist_msg(
+                "This is not a tool call:",
+                "",
+                "special_function",
+                "{\"arg1\": 1}"),
+            common_chat_parse(
+                "This is not a tool call:\n"
+                "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_HERMES_2_PRO}));
+        // assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+        //     common_chat_parse(
+        //         "I'm\nthinkingHello, world!\nWhat's up?",
+        //         COMMON_CHAT_FORMAT_HERMES_2_PRO));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_HERMES_2_PRO_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_md,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ true,
+                    /* .thinking_forced_open = */ false,
+                    /* .parse_tool_calls = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unparsed_md_partial,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```",
+                /* is_partial= */ true,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ true,
+                    /* .thinking_forced_open = */ false,
+                }));
+        assert_msg_equals(message_assist_thoughts_unopened_unparsed,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                       "\n"
                       "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n"
                       "");
-        test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools,
+        test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools,
                       "\n"
-                      "{\"name\": \"python\", \"arguments\": {\"code\": \"print('hey')\"}}\n"
+                      "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n"
                       "");
+        assert_msg_equals(
+            simple_assist_msg("", /* reasoning_content= */ "nah uhg"),
+            common_chat_parse(
+                "nah uhg",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
     }
     {
         auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja");
@@ -805,6 +1067,13 @@ static void test_template_output_parsers() {
                           inputs_tools_builtin)
                           .format);
 
+        assert_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_LLAMA_3_X}));
+
         // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools,
                       "<|python_tag|>code_interpreter.call(code=\"print('hey')\")");
@@ -831,7 +1100,25 @@ static void test_template_output_parsers() {
         assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
                       common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
-                      common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+            common_chat_templates_apply(tmpls.get(), inputs_tools).format);
+        assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY,
+                        common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
+
+        for (auto is_partial : { false, true }) {
+            assert_equals(
+                message_assist_call,
+                common_chat_parse(
+                    "{\"arg1\": 1}",
+                    is_partial,
+                    {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1}));
+        }
+
+        assert_equals(
+            message_assist_call,
+            common_chat_parse(
+                "{\"arg1\": 1}<",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1}));
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
@@ -844,6 +1131,47 @@ static void test_template_output_parsers() {
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
 
+        assert_msg_equals(
+            simple_assist_msg(
+                "Hello, world!\nnono\nWhat's up?",
+                "",
+                "special_function",
+                "{\"arg1\": 1}"),
+            common_chat_parse(
+                "all\n"
+                "Hello, world!\n"
+                "nono\n"
+                "What's up?>>>special_function\n"
+                "{\"arg1\": 1}\n",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call_python_lines,
+            common_chat_parse(
+                "python\n"
+                "# This is a program:\n"
+                "print('hey')",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call_python_lines_unclosed,
+            common_chat_parse(
+                "python\n"
+                "# This is a program:\n"
+                "print('hey')",
+                /* is_partial= */ true,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist_call,
+            common_chat_parse(
+                "special_function\n"
+                "{\"arg1\": 1} \n                    ",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+        assert_msg_equals(message_assist,
+            common_chat_parse(
+                "all\n"
+                "Hello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2}));
+
         test_templates(tmpls.get(), end_tokens, message_assist, {},
                       "all\n"
                       "Hello, world!\n"
@@ -869,22 +1197,73 @@ static void test_template_output_parsers() {
         auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja");
         std::vector   end_tokens{ "<|end▁of▁sentence|>" };
 
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
+        for (const auto & inputs : { inputs_no_tools, inputs_tools }) {
+            auto params = common_chat_templates_apply(tmpls.get(), inputs);
+            assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, params.format);
+            assert_equals(true, params.thinking_forced_open);
+        }
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(
+            simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"),
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
+        assert_msg_equals(
+            simple_assist_msg("", "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with"),
+            common_chat_parse(
+                "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with",
+                /* is_partial= */ true,
+                {
+                    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts_unopened_unparsed,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
         assert_msg_equals(message_assist_thoughts,
             // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true.
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
         // test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
         //               "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
         //               "```json\n"
@@ -901,16 +1280,32 @@ static void test_template_output_parsers() {
 
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
         assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1,                   common_chat_templates_apply(tmpls.get(), inputs_tools).format);
-        assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING, common_chat_templates_apply(tmpls.get(), inputs_tools_think).format);
 
         test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
         test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
-        assert_msg_equals(message_assist_thoughts_unparsed_think,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+        assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
         assert_msg_equals(message_assist_thoughts,
-            common_chat_parse("I'm thinkingHello, world!\nWhat's up?",
-            COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
+        assert_msg_equals(message_assist_thoughts,
+            common_chat_parse(
+                "I'm\nthinkingHello, world!\nWhat's up?",
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                    /* .reasoning_in_content = */ false,
+                    /* .thinking_forced_open = */ true,
+                }));
 
         assert_msg_equals(message_assist_call_thoughts_unparsed,
             common_chat_parse(
@@ -919,7 +1314,17 @@ static void test_template_output_parsers() {
                 "```json\n"
                 "{\"arg1\": 1}\n"
                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
-                COMMON_CHAT_FORMAT_DEEPSEEK_R1));
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
+        assert_msg_equals(message_assist_call,
+            common_chat_parse(
+                "<|tool▁calls|>function<|tool▁sep|>special_function\n"
+                "```json\n"
+                "{\"arg1\": 1}\n"
+                "```<|tool▁call▁end|><|tool▁calls▁end|>",
+                /* is_partial= */ false,
+                {COMMON_CHAT_FORMAT_DEEPSEEK_R1}));
+
         assert_msg_equals(message_assist_call_thoughts,
             common_chat_parse(
                 "I'm\nthinking\n\n"
@@ -927,7 +1332,11 @@ static void test_template_output_parsers() {
                 "```json\n"
                 "{\"arg1\": 1}\n"
                 "```<|tool▁call▁end|><|tool▁calls▁end|>",
-                COMMON_CHAT_FORMAT_DEEPSEEK_R1_EXTRACT_REASONING));
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
+                }));
         test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
                 "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n"
                 "```json\n"
@@ -936,7 +1345,93 @@ static void test_template_output_parsers() {
     }
 }
 
+static void test_msg_diffs_compute() {
+    printf("[%s]\n", __func__);
+    {
+        common_chat_msg msg1;
+
+        common_chat_msg msg2;
+        msg2.content = "Hello, world!";
+
+        common_chat_msg_diff diff;
+        diff.content_delta = "Hello, world!";
+
+        assert_equals(
+            {diff},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg1;
+        msg1.content = "Hello,";
+
+        common_chat_msg msg2;
+        msg2.content = "Hello, world!";
+
+        common_chat_msg_diff diff;
+        diff.content_delta = " world!";
+
+        assert_equals(
+            {diff},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg0;
+
+        common_chat_msg msg1;
+        msg1.tool_calls = { { "special_function", "{\"ar", /* .id = */ "123" } };
+
+        common_chat_msg msg2;
+        msg2.tool_calls = { { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } };
+
+        common_chat_msg_diff diff01;
+        diff01.tool_call_index = 0;
+        diff01.tool_call_delta.name = "special_function";
+        diff01.tool_call_delta.id = "123";
+        diff01.tool_call_delta.arguments = "{\"ar";
+
+        assert_equals(
+            {diff01},
+            common_chat_msg_diff::compute_diffs(msg0, msg1));
+
+        common_chat_msg_diff diff12;
+        diff12.tool_call_index = 0;
+        // Note: neither id nor name change here.
+        diff12.tool_call_delta.arguments = "g1\": 1}";
+
+        assert_equals(
+            {diff12},
+            common_chat_msg_diff::compute_diffs(msg1, msg2));
+    }
+    {
+        common_chat_msg msg0;
+
+        common_chat_msg msg2;
+        msg2.tool_calls = {
+            { "f1", "{\"arg1\": 1}", /* .id = */ "123" },
+            { "f2", "{\"arg2\": 2}", /* .id = */ "222" },
+        };
+
+        common_chat_msg_diff diff1;
+        diff1.tool_call_index = 0;
+        diff1.tool_call_delta.name = "f1";
+        diff1.tool_call_delta.id = "123";
+        diff1.tool_call_delta.arguments = "{\"arg1\": 1}";
+
+        common_chat_msg_diff diff2;
+        diff2.tool_call_index = 1;
+        diff2.tool_call_delta.name = "f2";
+        diff2.tool_call_delta.id = "222";
+        diff2.tool_call_delta.arguments = "{\"arg2\": 2}";
+
+        assert_equals(
+            {diff1, diff2},
+            common_chat_msg_diff::compute_diffs(msg0, msg2));
+    }
+}
+
 int main(int argc, char ** argv) {
+    common_log_set_verbosity_thold(999);
+
     // try {
 #ifndef _WIN32
         if (argc > 1) {
@@ -969,6 +1464,7 @@ int main(int argc, char ** argv) {
         } else
 #endif
         {
+            test_msg_diffs_compute();
             test_msgs_oaicompat_json_conversion();
             test_tools_oaicompat_json_conversion();
             test_template_output_parsers();
diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/tests/test-gbnf-validator.cpp
similarity index 98%
rename from examples/gbnf-validator/gbnf-validator.cpp
rename to tests/test-gbnf-validator.cpp
index a610e6a0b..6547eec32 100644
--- a/examples/gbnf-validator/gbnf-validator.cpp
+++ b/tests/test-gbnf-validator.cpp
@@ -1,5 +1,5 @@
-#include "unicode.h"
-#include "llama-grammar.h"
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 #include 
 #include 
diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp
index eaf572c66..3f0c312e2 100644
--- a/tests/test-gguf.cpp
+++ b/tests/test-gguf.cpp
@@ -16,6 +16,7 @@ constexpr int offset_has_data    = 3000;
 
 enum handcrafted_file_type {
     HANDCRAFTED_HEADER_BAD_MAGIC           =  10,
+    HANDCRAFTED_HEADER_BAD_VERSION_0       =  15,
     HANDCRAFTED_HEADER_BAD_VERSION_1       =  20,
     HANDCRAFTED_HEADER_BAD_VERSION_FUTURE  =  30,
     HANDCRAFTED_HEADER_BAD_N_TENSORS       =  40,
@@ -51,6 +52,7 @@ enum handcrafted_file_type {
 static std::string handcrafted_file_type_name(const enum handcrafted_file_type hft) {
     switch (hft) {
         case HANDCRAFTED_HEADER_BAD_MAGIC:           return "HEADER_BAD_MAGIC";
+        case HANDCRAFTED_HEADER_BAD_VERSION_0:       return "HEADER_BAD_VERSION_0";
         case HANDCRAFTED_HEADER_BAD_VERSION_1:       return "HEADER_BAD_VERSION_1";
         case HANDCRAFTED_HEADER_BAD_VERSION_FUTURE:  return "HEADER_BAD_VERSION_FUTURE";
         case HANDCRAFTED_HEADER_BAD_N_KV:            return "HEADER_BAD_N_KV";
@@ -171,7 +173,10 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft
         helper_write(file, GGUF_MAGIC, 4);
     }
 
-    if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
+    if (hft == HANDCRAFTED_HEADER_BAD_VERSION_0) {
+        const uint32_t version = 0;
+        helper_write(file, version);
+    } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_1) {
         const uint32_t version = 1;
         helper_write(file, version);
     } else if (hft == HANDCRAFTED_HEADER_BAD_VERSION_FUTURE) {
@@ -660,6 +665,7 @@ static std::pair test_handcrafted_file(const unsigned int seed) {
 
     const std::vector hfts = {
         HANDCRAFTED_HEADER_BAD_MAGIC,
+        HANDCRAFTED_HEADER_BAD_VERSION_0,
         HANDCRAFTED_HEADER_BAD_VERSION_1,
         HANDCRAFTED_HEADER_BAD_VERSION_FUTURE,
         HANDCRAFTED_HEADER_BAD_N_KV,
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 890608648..6d64f0737 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -2,10 +2,13 @@
 #undef NDEBUG
 #endif
 
-#include "unicode.h"
-#include "llama-grammar.h"
 #include "json-schema-to-grammar.h"
 
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
+#include 
+
 #include 
 #include 
 #include 
diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp
index 3c19220e1..566b039a0 100644
--- a/tests/test-grammar-llguidance.cpp
+++ b/tests/test-grammar-llguidance.cpp
@@ -2,7 +2,6 @@
 #    undef NDEBUG
 #endif
 
-#include "unicode.h"
 #include "sampling.h"
 
 #include 
@@ -84,7 +83,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
 
             fprintf(stderr,
                     "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
-                    "command:     ./llama-gbnf-validator test-grammar-integration.grammar.gbnf "
+                    "command:     ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
                     "test-grammar-integration.string.txt\n\n");
         } else {
             fprintf(stdout, "✅︎\n");
diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp
index 259172d99..67821a2d5 100644
--- a/tests/test-grammar-parser.cpp
+++ b/tests/test-grammar-parser.cpp
@@ -3,7 +3,9 @@
 #endif
 
 #include "llama.h"
-#include "llama-grammar.h"
+
+// TODO: shold not include libllama sources
+#include "../src/llama-grammar.h"
 
 #include 
 
diff --git a/tests/test-json-partial.cpp b/tests/test-json-partial.cpp
new file mode 100644
index 000000000..bc136bece
--- /dev/null
+++ b/tests/test-json-partial.cpp
@@ -0,0 +1,237 @@
+#include "common.h"
+#include "json-partial.h"
+#include 
+#include 
+#include 
+
+template  static void assert_equals(const T & expected, const T & actual) {
+  if (expected != actual) {
+      std::cerr << "Expected: " << expected << std::endl;
+      std::cerr << "Actual: " << actual << std::endl;
+      std::cerr << std::flush;
+      throw std::runtime_error("Test failed");
+  }
+}
+
+static void test_json_healing() {
+  auto parse = [](const std::string & str) {
+      std::cerr << "# Parsing: " << str << '\n';
+      std::string::const_iterator it = str.begin();
+      const auto end = str.end();
+      common_json out;
+      std::string healing_marker = "$llama.cpp.json$";
+      if (common_json_parse(it, end, healing_marker, out)) {
+          auto dump = out.json.dump();
+          std::cerr << "Parsed: " << dump << '\n';
+          std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n';
+          std::string result;
+          if (!out.healing_marker.json_dump_marker.empty()) {
+              auto i = dump.find(out.healing_marker.json_dump_marker);
+              if (i == std::string::npos) {
+                  throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")");
+              }
+              result = dump.substr(0, i);
+          } else {
+            result = dump;
+          }
+          std::cerr << "Result: " << result << '\n';
+          if (string_starts_with(str, result)) {
+            std::cerr << "Failure!\n";
+          }
+        //   return dump;
+      } else {
+        throw std::runtime_error("Failed to parse: " + str);
+      }
+
+  };
+  auto parse_all = [&](const std::string & str) {
+      for (size_t i = 1; i < str.size(); i++) {
+          parse(str.substr(0, i));
+      }
+  };
+  parse_all("{\"a\": \"b\"}");
+  parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}");
+
+  parse_all("[{\"a\": \"b\"}]");
+
+  auto test = [&](const std::vector & inputs, const std::string & expected, const std::string & expected_marker) {
+      for (const auto & input : inputs) {
+        common_json out;
+        assert_equals(true, common_json_parse(input, "$foo", out));
+        assert_equals(expected, out.json.dump());
+        assert_equals(expected_marker, out.healing_marker.json_dump_marker);
+      }
+  };
+  // No healing needed:
+  test(
+    {
+      R"([{"a":"b"}, "y"])",
+    },
+    R"([{"a":"b"},"y"])",
+    ""
+  );
+  // Partial literals can't be healed:
+  test(
+    {
+      R"([1)",
+      R"([tru)",
+      R"([n)",
+      R"([nul)",
+      R"([23.2)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({"a": 1)",
+      R"({"a": tru)",
+      R"({"a": n)",
+      R"({"a": nul)",
+      R"({"a": 23.2)",
+    },
+    R"({"a":"$foo"})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({)",
+    },
+    R"({"$foo":1})",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([)",
+    },
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  // Healing right after a full literal
+  test(
+    {
+      R"(1 )",
+    },
+    R"(1)",
+    ""
+  );
+  test(
+    {
+      R"(true)",
+      R"(true )",
+    },
+    R"(true)",
+    ""
+  );
+  test(
+    {
+      R"(null)",
+      R"(null )",
+    },
+    R"(null)",
+    ""
+  );
+  test(
+    {
+      R"([1 )",
+    },
+    R"([1,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{})",
+      R"([{} )",
+    },
+    R"([{},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true)",
+    },
+    // TODO: detect the true/false/null literal was complete
+    R"(["$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([true )",
+    },
+    R"([true,"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([true,)",
+    },
+    R"([true,"$foo"])",
+    R"("$foo)"
+  );
+  // Test nesting
+  test(
+    {
+      R"([{"a": [{"b": [{)",
+    },
+    R"([{"a":[{"b":[{"$foo":1}]}]}])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"([{"a": [{"b": [)",
+    },
+    R"([{"a":[{"b":["$foo"]}]}])",
+    R"("$foo)"
+  );
+
+  test(
+    {
+      R"([{"a": "b"})",
+      R"([{"a": "b"} )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"(,"$foo)"
+  );
+  test(
+    {
+      R"([{"a": "b"},)",
+      R"([{"a": "b"}, )",
+    },
+    R"([{"a":"b"},"$foo"])",
+    R"("$foo)"
+  );
+  test(
+    {
+      R"({ "code)",
+    },
+    R"({"code$foo":1})",
+    R"($foo)"
+  );
+  test(
+    {
+      R"({ "code\)",
+    },
+    R"({"code\\$foo":1})",
+    R"(\$foo)"
+  );
+  test(
+    {
+      R"({ "code")",
+    },
+    R"({"code":"$foo"})",
+    R"(:"$foo)"
+  );
+  test(
+    {
+      R"({ "key")",
+    },
+    R"({"key":"$foo"})",
+    R"(:"$foo)"
+  );
+}
+
+int main() {
+    test_json_healing();
+    std::cerr << "All tests passed.\n";
+    return 0;
+}
diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp
index 4d78e9142..78ee55e24 100755
--- a/tests/test-json-schema-to-grammar.cpp
+++ b/tests/test-json-schema-to-grammar.cpp
@@ -4,7 +4,9 @@
 
 #include "json-schema-to-grammar.h"
 
-#include "llama-grammar.h"
+#include "../src/llama-grammar.h"
+
+#include 
 
 #include 
 #include 
@@ -597,6 +599,22 @@ static void test_all(const std::string & lang, std::function
 #include 
diff --git a/tests/test-mtmd-c-api.c b/tests/test-mtmd-c-api.c
new file mode 100644
index 000000000..02e762e6a
--- /dev/null
+++ b/tests/test-mtmd-c-api.c
@@ -0,0 +1,63 @@
+#include 
+#include 
+
+#include "mtmd.h"
+
+int main(void) {
+    printf("\n\nTesting libmtmd C API...\n");
+    printf("--------\n\n");
+
+    struct mtmd_context_params params = mtmd_context_params_default();
+    printf("Default image marker: %s\n", params.image_marker);
+
+    mtmd_input_chunks * chunks = mtmd_test_create_input_chunks();
+
+    if (!chunks) {
+        fprintf(stderr, "Failed to create input chunks\n");
+        return 1;
+    }
+
+    size_t n_chunks = mtmd_input_chunks_size(chunks);
+    printf("Number of chunks: %zu\n", n_chunks);
+    assert(n_chunks > 0);
+
+    for (size_t i = 0; i < n_chunks; i++) {
+        const mtmd_input_chunk * chunk = mtmd_input_chunks_get(chunks, i);
+        assert(chunk != NULL);
+        enum mtmd_input_chunk_type type = mtmd_input_chunk_get_type(chunk);
+        printf("Chunk %zu type: %d\n", i, type);
+
+        if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
+            size_t n_tokens;
+            const llama_token * tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens);
+            printf("    Text chunk with %zu tokens\n", n_tokens);
+            assert(tokens != NULL);
+            assert(n_tokens > 0);
+            for (size_t j = 0; j < n_tokens; j++) {
+                assert(tokens[j] >= 0);
+                printf("    > Token %zu: %d\n", j, tokens[j]);
+            }
+
+        } else if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
+            const mtmd_image_tokens * image_tokens = mtmd_input_chunk_get_tokens_image(chunk);
+            size_t n_tokens = mtmd_image_tokens_get_n_tokens(image_tokens);
+            size_t nx = mtmd_image_tokens_get_nx(image_tokens);
+            size_t ny = mtmd_image_tokens_get_ny(image_tokens);
+            const char * id = mtmd_image_tokens_get_id(image_tokens);
+            assert(n_tokens > 0);
+            assert(nx > 0);
+            assert(ny > 0);
+            assert(id != NULL);
+            printf("    Image chunk with %zu tokens\n", n_tokens);
+            printf("    Image size: %zu x %zu\n", nx, ny);
+            printf("    Image ID: %s\n", id);
+        }
+    }
+
+    // Free the chunks
+    mtmd_input_chunks_free(chunks);
+
+    printf("\n\nDONE: test libmtmd C API...\n");
+
+    return 0;
+}
diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp
index f90c92b4b..558f87721 100644
--- a/tests/test-opt.cpp
+++ b/tests/test-opt.cpp
@@ -57,7 +57,8 @@ static helper_ctx_data helper_get_ctx_data(
         enum ggml_opt_loss_type loss_type          = GGML_OPT_LOSS_TYPE_SUM) {
     std::vector datasets(ndata);
     for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
-        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard);
+        ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
+            GGML_TYPE_F32, GGML_TYPE_F32, ne_datapoint, ne_label, ndata, ndata_shard);
 
         float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
         float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@@ -74,7 +75,8 @@ static helper_ctx_data helper_get_ctx_data(
         datasets[ndata_shard-1] = dataset;
     }
 
-    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1);
+    ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(
+        GGML_TYPE_F32, GGML_TYPE_F32, 1, 0, ndata, /*ndata_shard =*/ 1);
 
     float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised));
 
@@ -113,7 +115,7 @@ static helper_ctx_data helper_get_ctx_data(
 
     struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
     ggml_set_name(weights, "weights");
-    ggml_set_param(ctx_static, weights);
+    ggml_set_param(weights);
 
     struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights);
 
@@ -127,8 +129,11 @@ static helper_ctx_data helper_get_ctx_data(
     GGML_ASSERT(nbatch_logical % nbatch_physical == 0);
     const int32_t opt_period = nbatch_logical / nbatch_physical;
 
-    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
-    opt_params.opt_period = opt_period;
+    struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, loss_type);
+    opt_params.ctx_compute = ctx_compute;
+    opt_params.inputs      = inputs;
+    opt_params.outputs     = outputs;
+    opt_params.opt_period  = opt_period;
     if (!optimizer_defaults) {
         opt_params.get_opt_pars = helper_get_test_opt_pars;
     }
@@ -264,8 +269,9 @@ static std::pair test_grad(ggml_backend_sched_t backend_sched, ggml_ba
 
     for (int idata = 0; idata < ndata; ++idata) {
         const float idataf = idata;
+        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
         ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-        ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+        ggml_opt_eval(cd.opt_ctx, cd.result);
         ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
     }
 
@@ -334,8 +340,9 @@ static std::pair test_forward_backward(
     } else {
         for (int idata = 0; idata < ndata; ++idata) {
             const float idataf = idata;
+            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
             ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-            ggml_opt_forward(cd.opt_ctx, cd.result);
+            ggml_opt_eval(cd.opt_ctx, cd.result);
             ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
         }
     }
@@ -367,7 +374,8 @@ static std::pair test_forward_backward(
     float w0;
     ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
     for (int i = 0; i < 10; ++i) {
-        ggml_opt_forward_backward(cd.opt_ctx, nullptr);
+        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
+        ggml_opt_eval(cd.opt_ctx, cd.result);
     }
     ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
 
@@ -387,8 +395,9 @@ static std::pair test_forward_backward(
     } else {
         for (int idata = 0; idata < ndata; ++idata) {
             const float idataf = idata;
+            ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
             ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-            ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+            ggml_opt_eval(cd.opt_ctx, cd.result);
             ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
         }
     }
@@ -492,14 +501,16 @@ static std::pair test_idata_split(ggml_backend_sched_t backend_sched,
             int idata = 0;
             for (; idata < idata_split; ++idata) {
                 const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
                 ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_opt_eval(cd.opt_ctx, cd.result);
                 ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
             }
             for (; idata < ndata; ++idata) {
                 const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ false);
                 ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
-                ggml_opt_forward(cd.opt_ctx, cd.result2);
+                ggml_opt_eval(cd.opt_ctx, cd.result2);
                 ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float));
             }
         }
@@ -573,7 +584,6 @@ static std::pair test_gradient_accumulation(
 
     struct helper_ctx_data cd = helper_get_ctx_data(
         backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);
-    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
 
     std::vector grad_history(ndata);
     for (int64_t idata = 0; idata < ndata; ++idata) {
@@ -584,15 +594,17 @@ static std::pair test_gradient_accumulation(
         if (nbatch_physical == 1) {
             for (int idata = 0; idata < ndata; ++idata) {
                 const float idataf = idata;
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
                 ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float));
-                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_opt_eval(cd.opt_ctx, cd.result);
                 ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float));
             }
         } else if (nbatch_physical == 2) {
             for (int idata = 0; idata < ndata; idata += 2) {
                 const float idataf[2] = {float(idata + 0), float(idata + 1)};
+                ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
                 ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float));
-                ggml_opt_forward_backward(cd.opt_ctx, cd.result);
+                ggml_opt_eval(cd.opt_ctx, cd.result);
 
                 grad_history[idata + 0] = 0.0f;
                 ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float));
@@ -617,7 +629,7 @@ static std::pair test_gradient_accumulation(
                 }
                 subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol);
                 subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol);
-                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0, atol);
             } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) {
                 if (nbatch_physical == 1) {
                     subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol);
@@ -630,7 +642,7 @@ static std::pair test_gradient_accumulation(
                 }
                 subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol);
                 subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol);
-                subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol);
+                subtest_ok = subtest_ok && almost_equal(grad_history[5], 6.0/ndata, atol);
             } else {
                 GGML_ASSERT(false);
             }
@@ -692,7 +704,8 @@ static std::pair test_regression(ggml_backend_sched_t backend_sched, g
     std::mt19937 gen(12345);
     std::normal_distribution nd{0.0f, 0.1f};
 
-    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression);
+    ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
+        GGML_TYPE_F32, GGML_TYPE_F32, 1, 1, ndata_regression, ndata_regression);
 
     float * data   = ggml_get_data_f32(ggml_opt_dataset_data(  dataset));
     float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset));
@@ -733,15 +746,14 @@ static std::pair test_regression(ggml_backend_sched_t backend_sched, g
 
     struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
     ggml_set_name(a, "a");
-    ggml_set_param(ctx_static, a);
+    ggml_set_param(a);
 
     struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1);
     ggml_set_name(b, "b");
-    ggml_set_param(ctx_static, b);
+    ggml_set_param(b);
 
     struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b);
     ggml_set_name(f, "f");
-    ggml_set_param(ctx_static, f);
 
     ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend);
     const float a0 = 1.0f;
@@ -853,7 +865,7 @@ int main(void) {
         backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
 
         ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
-            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false);
+            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);
 
         printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
         printf("  Device description: %s\n", ggml_backend_dev_description(devs[i]));
diff --git a/examples/quantize-stats/quantize-stats.cpp b/tests/test-quantize-stats.cpp
similarity index 99%
rename from examples/quantize-stats/quantize-stats.cpp
rename to tests/test-quantize-stats.cpp
index dd07ab9b3..a284a1f0c 100644
--- a/examples/quantize-stats/quantize-stats.cpp
+++ b/tests/test-quantize-stats.cpp
@@ -1,8 +1,10 @@
 #include "ggml.h"
+#include "ggml-cpu.h"
 #include "llama.h"
-#include "llama-model.h"
 #include "common.h"
 
+#include "../src/llama-model.h"
+
 #include 
 #include 
 #include 
diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp
new file mode 100644
index 000000000..ffad18978
--- /dev/null
+++ b/tests/test-regex-partial.cpp
@@ -0,0 +1,288 @@
+//  Tests common_regex (esp. its partial final matches support).
+
+#include "common.h"
+#include "regex-partial.h"
+
+#include 
+#include 
+#include 
+
+template  static void assert_equals(const T & expected, const T & actual) {
+    if (expected != actual) {
+        std::cerr << "Expected: " << expected << std::endl;
+        std::cerr << "  Actual: " << actual << std::endl;
+        std::cerr << std::flush;
+        throw std::runtime_error("Test failed");
+    }
+}
+
+struct test_case {
+    std::string pattern;
+    struct input_output {
+        std::string input;
+        common_regex_match output;
+    };
+    std::vector inputs_outputs;
+};
+
+static std::string common_regex_match_type_name(common_regex_match_type type) {
+    switch (type) {
+        case COMMON_REGEX_MATCH_TYPE_NONE:
+            return "COMMON_REGEX_MATCH_TYPE_NONE";
+        case COMMON_REGEX_MATCH_TYPE_PARTIAL:
+            return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
+        case COMMON_REGEX_MATCH_TYPE_FULL:
+            return "COMMON_REGEX_MATCH_TYPE_FULL";
+    }
+    return "?";
+}
+
+static void test_regex() {
+    printf("[%s]\n", __func__);
+    auto test = [](const test_case & test_case) {
+        common_regex cr(test_case.pattern);
+        std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
+        // std::cout << "    partial rev: " << cr.reversed_partial_pattern.str() << '\n';
+        for (const auto & input_output : test_case.inputs_outputs) {
+            std::cout << "  Input: " << input_output.input << '\n';
+            auto m = cr.search(input_output.input, 0);
+            if (m != input_output.output) {
+                auto match_to_str = [&](const std::optional & m) {
+                    std::ostringstream ss;
+                    if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
+                        ss << "";
+                    } else {
+                        GGML_ASSERT(!input_output.output.groups.empty());
+                        std::vector parts;
+                        for (const auto & g : m->groups) {
+                            parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
+                        }
+                        ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
+                    }
+                    return ss.str();
+                };
+                std::cout << "    Expected: " << match_to_str(input_output.output) << '\n';
+                std::cout << "         Got: " << match_to_str(m) << '\n';
+                std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
+
+                throw std::runtime_error("Test failed");
+            }
+        }
+    };
+    test({
+        "a",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
+            {"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
+        }
+    });
+    test({
+        "abcd",
+        {
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"d", {}},
+            {"bcd", {}},
+            {"cde", {}},
+            {"cd", {}},
+            {"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
+            {"abbie", {}},
+            {"", {}},
+        }
+    });
+    test({
+        ".*?ab",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+        }
+    });
+    test({
+        "a.*?b",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+            {"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"d", {}},
+            {"b", {}},
+        }
+    });
+    test({
+        "ab(?:cd){2,4}ef",
+        {
+            // {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abcde", {}},
+            {"abcdef", {}},
+            {"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
+            {"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
+            {"abcdcdcdcdcdef", {}},
+            {"abcde", {}},
+            {"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
+        }
+    });
+    test({
+        "a(?:rte| pure )fact",
+        {
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"fact", {}},
+            {"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
+            {"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
+            {"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
+            {"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
+            {"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
+            {"" , {}},
+            {"pure", {}},
+            {"pure fact", {}},
+        }
+    });
+    test({
+        "abc",
+        {
+            {" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+            {" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
+            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
+            {"b", {}},
+            {"c", {}},
+            {"", {}},
+        }
+    });
+
+    test({
+        "(?:abc)?\\s*def",
+        {
+            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
+            {"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
+            {"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
+            {"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
+            {" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
+            {"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
+        }
+    });
+
+    test({
+        "a+b",
+        {
+            {"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
+            {"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
+            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
+        }
+    });
+
+    test({
+        "(?:"
+            "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
+            "("                          // match 2 (open_tag)
+                ""
+                "|"
+                "|"
+                "|"
+                "|"
+                "|"
+                "|"
+                "|"
+            ")?"
+            "(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
+        ")"
+        "|]+)>"            // match 4 (function name)
+        "|", // match 5 (function name again)
+        {
+            {"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
+            {" {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
+            {"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
+            {"Let's call something\n{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
+            {"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
+            {"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
+            {"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
+            {" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
+            {" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
+            {" {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
+            {"", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
+
+        }
+    });
+}
+
+static void test_regex_to_reversed_partial_regex() {
+    printf("[%s]\n", __func__);
+
+    assert_equals(
+        "((?:(?:c)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("abc"));
+
+    assert_equals(
+        "(a+)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a+"));
+
+    assert_equals(
+        "(a*)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a*"));
+
+    assert_equals(
+        "(a?)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a?"));
+
+    assert_equals(
+        "([a-z])[\\s\\S]*",
+        regex_to_reversed_partial_regex("[a-z]"));
+
+    assert_equals(
+        "((?:\\w+)?[a-z])[\\s\\S]*",
+        regex_to_reversed_partial_regex("[a-z]\\w+"));
+
+    assert_equals(
+        "((?:a|b))[\\s\\S]*",
+        regex_to_reversed_partial_regex("(?:a|b)"));
+    assert_equals(
+        "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("abcd"));
+    assert_equals(
+        "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ??
+        regex_to_reversed_partial_regex("a*b"));
+    assert_equals(
+        "((?:(?:b)?a)?.*)[\\s\\S]*",
+        regex_to_reversed_partial_regex(".*?ab"));
+    assert_equals(
+        "((?:(?:b)?.*)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a.*?b"));
+    assert_equals(
+        "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a(bc)d"));
+    assert_equals(
+        "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("a(bc|de)"));
+    assert_equals(
+        "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*",
+        regex_to_reversed_partial_regex("ab{2,4}c"));
+}
+
+int main() {
+    test_regex_to_reversed_partial_regex();
+    test_regex();
+    std::cout << "All tests passed.\n";
+}
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index f1f87d454..6300f25ca 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -98,7 +98,7 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector
+#include 
+#include 
+#include "llama.h"
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "sampling.h"
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
+        return 1;
+    }
+
+    common_init();
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+
+    //llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
+    //    if (level == GGML_LOG_LEVEL_ERROR) {
+    //        common_log_add(common_log_main(), level, "%s", text);
+    //    }
+    //}, NULL);
+
+    auto cparams = common_context_params_to_llama(params);
+
+    int dev_count = ggml_backend_dev_count();
+    int gpu_dev_count = 0;
+    for (int i = 0; i < dev_count; ++i) {
+        auto * dev = ggml_backend_dev_get(i);
+        if (dev && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
+            gpu_dev_count++;
+        }
+    }
+    const int num_models = gpu_dev_count + 1 + 1; // GPUs + 1 CPU model + 1 layer split
+    //const int num_models = std::max(1, gpu_dev_count);
+    const int num_contexts = std::max(1, params.n_parallel);
+
+    std::vector models;
+    std::vector threads;
+    std::atomic failed = false;
+
+    for (int m = 0; m < num_models; ++m) {
+        auto mparams = common_model_params_to_llama(params);
+
+        if (m < gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = m;
+        } else if (m == gpu_dev_count) {
+            mparams.split_mode = LLAMA_SPLIT_MODE_NONE;
+            mparams.main_gpu = -1; // CPU model
+        } else {
+            mparams.split_mode = LLAMA_SPLIT_MODE_LAYER;;
+        }
+
+        llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
+        if (model == NULL) {
+            LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str());
+            return 1;
+        }
+
+        models.emplace_back(model);
+    }
+
+    for  (int m = 0; m < num_models; ++m) {
+        auto * model = models[m].get();
+        for (int c = 0; c < num_contexts; ++c) {
+            threads.emplace_back([&, m, c, model]() {
+                LOG_INF("Creating context %d/%d for model %d/%d\n", c + 1, num_contexts, m + 1, num_models);
+
+                llama_context_ptr ctx { llama_init_from_model(model, cparams) };
+                if (ctx == NULL) {
+                    LOG_ERR("failed to create context\n");
+                    failed.store(true);
+                    return;
+                }
+
+                std::unique_ptr sampler { common_sampler_init(model, params.sampling), common_sampler_free };
+                if (sampler == NULL) {
+                    LOG_ERR("failed to create sampler\n");
+                    failed.store(true);
+                    return;
+                }
+
+                llama_batch batch = {};
+                {
+                    auto prompt = common_tokenize(ctx.get(), params.prompt, true);
+                    if (prompt.empty()) {
+                        LOG_ERR("failed to tokenize prompt\n");
+                        failed.store(true);
+                        return;
+                    }
+                    batch = llama_batch_get_one(prompt.data(), prompt.size());
+                    if (llama_decode(ctx.get(), batch)) {
+                        LOG_ERR("failed to decode prompt\n");
+                        failed.store(true);
+                        return;
+                    }
+                }
+
+                const auto * vocab = llama_model_get_vocab(model);
+                std::string result = params.prompt;
+
+                for (int i = 0; i < params.n_predict; i++) {
+                    llama_token token;
+                    if (batch.n_tokens > 0) {
+                        token = common_sampler_sample(sampler.get(), ctx.get(), batch.n_tokens - 1);
+                    } else {
+                        token = llama_vocab_bos(vocab);
+                    }
+
+                    result += common_token_to_piece(ctx.get(), token);
+
+                    if (llama_vocab_is_eog(vocab, token)) {
+                        break;
+                    }
+
+                    batch = llama_batch_get_one(&token, 1);
+                    if (llama_decode(ctx.get(), batch)) {
+                        LOG_ERR("Model %d/%d, Context %d/%d: failed to decode\n", m + 1, num_models, c + 1, num_contexts);
+                        failed.store(true);
+                        return;
+                    }
+                }
+
+                LOG_INF("Model %d/%d, Context %d/%d: %s\n\n", m + 1, num_models, c + 1, num_contexts, result.c_str());
+            });
+        }
+    }
+
+    for (auto & thread : threads) {
+        thread.join();
+    }
+
+    if (failed) {
+        LOG_ERR("One or more threads failed.\n");
+        return 1;
+    }
+
+    LOG_INF("All threads finished without errors.\n");
+    return 0;
+}
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
index 55425d88a..b183da47f 100644
--- a/tests/test-tokenizer-1-bpe.cpp
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -1,8 +1,9 @@
 #include "llama.h"
 #include "common.h"
-#include "unicode.h"
 #include "console.h"
 
+#include "../src/unicode.h"
+
 #include 
 #include 
 #include 
diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp
index 9e7b77f31..ba6e94ba8 100644
--- a/tests/test-tokenizer-1-spm.cpp
+++ b/tests/test-tokenizer-1-spm.cpp
@@ -1,8 +1,9 @@
 #include "llama.h"
 #include "common.h"
-#include "unicode.h"
 #include "console.h"
 
+#include "../src/unicode.h"
+
 #include 
 #include 
 #include 
diff --git a/tests/test-tokenizers-repo.sh b/tests/test-tokenizers-repo.sh
new file mode 100755
index 000000000..86e839133
--- /dev/null
+++ b/tests/test-tokenizers-repo.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+if [ $# -lt 2 ]; then
+    printf "Usage: $0   []\n"
+    exit 1
+fi
+
+if [ $# -eq 3 ]; then
+    toktest=$3
+else
+    toktest="./test-tokenizer-0"
+fi
+
+if [ ! -x $toktest ]; then
+    printf "Test executable \"$toktest\" not found!\n"
+    exit 1
+fi
+
+repo=$1
+folder=$2
+
+if [ -d $folder ] && [ -d $folder/.git ]; then
+    (cd $folder; git pull)
+else
+    git clone $repo $folder
+fi
+
+shopt -s globstar
+for gguf in $folder/**/*.gguf; do
+    if [ -f $gguf.inp ] && [ -f $gguf.out ]; then
+        $toktest $gguf
+    else
+        printf "Found \"$gguf\" without matching inp/out files, ignoring...\n"
+    fi
+done
+
diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt
new file mode 100644
index 000000000..d64956b84
--- /dev/null
+++ b/tools/CMakeLists.txt
@@ -0,0 +1,39 @@
+# dependencies
+
+find_package(Threads REQUIRED)
+
+# third-party
+
+# ...
+
+# flags
+
+llama_add_compile_flags()
+
+# tools
+
+if (EMSCRIPTEN)
+else()
+    add_subdirectory(batched-bench)
+    add_subdirectory(gguf-split)
+    add_subdirectory(imatrix)
+    add_subdirectory(llama-bench)
+    add_subdirectory(main)
+    add_subdirectory(perplexity)
+    add_subdirectory(quantize)
+    if (LLAMA_BUILD_SERVER)
+        add_subdirectory(server)
+    endif()
+    add_subdirectory(run)
+    add_subdirectory(tokenize)
+    add_subdirectory(tts)
+    add_subdirectory(mtmd)
+    if (GGML_RPC)
+        add_subdirectory(rpc)
+    endif()
+    if (NOT GGML_BACKEND_DL)
+        # these examples use the backends directly and cannot be built with dynamic loading
+        add_subdirectory(cvector-generator)
+        add_subdirectory(export-lora)
+    endif()
+endif()
diff --git a/examples/batched-bench/CMakeLists.txt b/tools/batched-bench/CMakeLists.txt
similarity index 100%
rename from examples/batched-bench/CMakeLists.txt
rename to tools/batched-bench/CMakeLists.txt
diff --git a/examples/batched-bench/README.md b/tools/batched-bench/README.md
similarity index 100%
rename from examples/batched-bench/README.md
rename to tools/batched-bench/README.md
diff --git a/examples/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp
similarity index 96%
rename from examples/batched-bench/batched-bench.cpp
rename to tools/batched-bench/batched-bench.cpp
index 0f4019293..a0a2e5ac5 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/tools/batched-bench/batched-bench.cpp
@@ -57,6 +57,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto * mem = llama_get_memory(ctx);
+
     const int32_t n_kv_max = llama_n_ctx(ctx);
 
     llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
@@ -123,8 +125,8 @@ int main(int argc, char ** argv) {
 
                 common_batch_clear(batch);
 
-                for (int i = 0; i < pp; ++i) {
-                    for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                    for (int i = 0; i < pp; ++i) {
                         common_batch_add(batch, 0, i, { j }, false);
                     }
                 }
@@ -132,7 +134,7 @@ int main(int argc, char ** argv) {
 
                 const auto t_pp_start = ggml_time_us();
 
-                llama_kv_self_clear(ctx);
+                llama_memory_clear(mem, false);
 
                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
                     LOG_ERR("%s: llama_decode() failed\n", __func__);
@@ -141,7 +143,7 @@ int main(int argc, char ** argv) {
 
                 if (is_pp_shared) {
                     for (int32_t i = 1; i < pl; ++i) {
-                        llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
+                        llama_memory_seq_cp(mem, 0, i, -1, -1);
                     }
                 }
 
diff --git a/examples/cvector-generator/CMakeLists.txt b/tools/cvector-generator/CMakeLists.txt
similarity index 100%
rename from examples/cvector-generator/CMakeLists.txt
rename to tools/cvector-generator/CMakeLists.txt
diff --git a/examples/cvector-generator/README.md b/tools/cvector-generator/README.md
similarity index 100%
rename from examples/cvector-generator/README.md
rename to tools/cvector-generator/README.md
diff --git a/examples/cvector-generator/completions.txt b/tools/cvector-generator/completions.txt
similarity index 100%
rename from examples/cvector-generator/completions.txt
rename to tools/cvector-generator/completions.txt
diff --git a/examples/cvector-generator/cvector-generator.cpp b/tools/cvector-generator/cvector-generator.cpp
similarity index 99%
rename from examples/cvector-generator/cvector-generator.cpp
rename to tools/cvector-generator/cvector-generator.cpp
index 2a9071550..d2d97e05c 100644
--- a/examples/cvector-generator/cvector-generator.cpp
+++ b/tools/cvector-generator/cvector-generator.cpp
@@ -342,7 +342,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 }
 
 static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) {
-    llama_kv_self_clear(ctx);
+    llama_memory_clear(llama_get_memory(ctx), true);
     if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;
diff --git a/examples/cvector-generator/mean.hpp b/tools/cvector-generator/mean.hpp
similarity index 100%
rename from examples/cvector-generator/mean.hpp
rename to tools/cvector-generator/mean.hpp
diff --git a/examples/cvector-generator/negative.txt b/tools/cvector-generator/negative.txt
similarity index 100%
rename from examples/cvector-generator/negative.txt
rename to tools/cvector-generator/negative.txt
diff --git a/examples/cvector-generator/pca.hpp b/tools/cvector-generator/pca.hpp
similarity index 100%
rename from examples/cvector-generator/pca.hpp
rename to tools/cvector-generator/pca.hpp
diff --git a/examples/cvector-generator/positive.txt b/tools/cvector-generator/positive.txt
similarity index 100%
rename from examples/cvector-generator/positive.txt
rename to tools/cvector-generator/positive.txt
diff --git a/examples/export-lora/CMakeLists.txt b/tools/export-lora/CMakeLists.txt
similarity index 100%
rename from examples/export-lora/CMakeLists.txt
rename to tools/export-lora/CMakeLists.txt
diff --git a/examples/export-lora/README.md b/tools/export-lora/README.md
similarity index 100%
rename from examples/export-lora/README.md
rename to tools/export-lora/README.md
diff --git a/examples/export-lora/export-lora.cpp b/tools/export-lora/export-lora.cpp
similarity index 100%
rename from examples/export-lora/export-lora.cpp
rename to tools/export-lora/export-lora.cpp
diff --git a/examples/gguf-split/CMakeLists.txt b/tools/gguf-split/CMakeLists.txt
similarity index 100%
rename from examples/gguf-split/CMakeLists.txt
rename to tools/gguf-split/CMakeLists.txt
diff --git a/examples/gguf-split/README.md b/tools/gguf-split/README.md
similarity index 100%
rename from examples/gguf-split/README.md
rename to tools/gguf-split/README.md
diff --git a/examples/gguf-split/gguf-split.cpp b/tools/gguf-split/gguf-split.cpp
similarity index 100%
rename from examples/gguf-split/gguf-split.cpp
rename to tools/gguf-split/gguf-split.cpp
diff --git a/examples/gguf-split/tests.sh b/tools/gguf-split/tests.sh
similarity index 100%
rename from examples/gguf-split/tests.sh
rename to tools/gguf-split/tests.sh
diff --git a/examples/imatrix/CMakeLists.txt b/tools/imatrix/CMakeLists.txt
similarity index 100%
rename from examples/imatrix/CMakeLists.txt
rename to tools/imatrix/CMakeLists.txt
diff --git a/examples/imatrix/README.md b/tools/imatrix/README.md
similarity index 98%
rename from examples/imatrix/README.md
rename to tools/imatrix/README.md
index 9aa2b2034..6d8897d98 100644
--- a/examples/imatrix/README.md
+++ b/tools/imatrix/README.md
@@ -1,4 +1,4 @@
-# llama.cpp/examples/imatrix
+# llama.cpp/tools/imatrix
 
 Compute an importance matrix for a model and given text dataset. Can be used during quantization to enhance the quality of the quantized models.
 More information is available here: https://github.com/ggml-org/llama.cpp/pull/4861
diff --git a/examples/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp
similarity index 98%
rename from examples/imatrix/imatrix.cpp
rename to tools/imatrix/imatrix.cpp
index fdbc97a5e..4250f507c 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/tools/imatrix/imatrix.cpp
@@ -26,7 +26,8 @@ static void print_usage(int, char ** argv) {
     LOG("\n    %s \\\n"
             "       -m model.gguf -f some-text.txt [-o imatrix.gguf] [--process-output] \\\n"
             "       [--no-ppl] [--chunk 123] [--output-frequency 10] [--save-frequency 0] \\\n"
-            "       [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...]\n" , argv[0]);
+            "       [--in-file imatrix-prev-0.gguf --in-file imatrix-prev-1.gguf ...] \\\n"
+            "       [--parse-special]\n" , argv[0]);
     LOG("\n");
 }
 
@@ -66,7 +67,7 @@ private:
     std::mutex                             m_mutex;
     std::vector               m_datasets;
     int32_t                                m_last_chunk = 0;
-    std::vector                     m_src1_data;
+    std::vector                      m_src1_data;
     std::vector                      m_ids; // the expert ids from ggml_mul_mat_id
 };
 
@@ -115,11 +116,13 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
     const bool is_host = ggml_backend_buffer_is_host(src1->buffer);
 
     if (!is_host) {
-        m_src1_data.resize(ggml_nelements(src1));
-        ggml_backend_tensor_get(src1, m_src1_data.data(), 0, ggml_nbytes(src1));
+        const size_t src1_nbytes = ggml_nbytes(src1);
+        m_src1_data.resize(src1_nbytes);
+        ggml_backend_tensor_get(src1, m_src1_data.data(), 0, src1_nbytes);
     }
 
-    const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
+    const char * data = is_host ? (const char *) src1->data : m_src1_data.data();
+    GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
 
     // this has been adapted to the new format of storing merged experts in a single 3d tensor
     // ref: https://github.com/ggml-org/llama.cpp/pull/6387
@@ -172,7 +175,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
 
                     const int64_t i11 = idx % src1->ne[1];
                     const int64_t i12 = row;
-                    const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
+                    const float * x = (const float *)(data + i11*src1->nb[1] + i12*src1->nb[2]);
 
                     e.counts[ex]++;
 
@@ -214,7 +217,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
         LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_chunk, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
         // TODO: higher dimensions
         for (int row = 0; row < (int)src1->ne[1]; ++row) {
-            const float * x = data + row * src1->ne[0];
+            const float * x = (const float *) (data + row * src1->nb[1]);
             e.counts[0]++;
             for (int j = 0; j < (int)src1->ne[0]; ++j) {
                 e.values[j] = std::fma(x[j], x[j], e.values[j]);
@@ -729,7 +732,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
     auto tim1 = std::chrono::high_resolution_clock::now();
     LOG_INF("%s: tokenizing the input ..\n", __func__);
 
-    std::vector tokens = common_tokenize(ctx, params.prompt, true);
+    std::vector tokens = common_tokenize(ctx, params.prompt, true, params.parse_special);
 
     auto tim2 = std::chrono::high_resolution_clock::now();
     LOG_INF("%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast(tim2-tim1).count());
@@ -793,7 +796,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
-        llama_kv_self_clear(ctx);
+        llama_memory_clear(llama_get_memory(ctx), true);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -899,7 +902,6 @@ int main(int argc, char ** argv) {
     params.out_file = "imatrix.gguf";
 
     params.n_ctx = 512;
-    params.logits_all = true;
     params.escape = false;
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_IMATRIX, print_usage)) {
diff --git a/examples/llama-bench/CMakeLists.txt b/tools/llama-bench/CMakeLists.txt
similarity index 100%
rename from examples/llama-bench/CMakeLists.txt
rename to tools/llama-bench/CMakeLists.txt
diff --git a/examples/llama-bench/README.md b/tools/llama-bench/README.md
similarity index 54%
rename from examples/llama-bench/README.md
rename to tools/llama-bench/README.md
index 6bbe4bb75..31a273087 100644
--- a/examples/llama-bench/README.md
+++ b/tools/llama-bench/README.md
@@ -1,4 +1,4 @@
-# llama.cpp/examples/llama-bench
+# llama.cpp/tools/llama-bench
 
 Performance testing tool for llama.cpp.
 
@@ -20,40 +20,50 @@ Performance testing tool for llama.cpp.
 ## Syntax
 
 ```
-usage: ./llama-bench [options]
+usage: llama-bench [options]
 
 options:
   -h, --help
+  --numa        numa mode (default: disabled)
+  -r, --repetitions                      number of times to repeat each test (default: 5)
+  --prio <0|1|2|3>                          process/thread priority (default: 0)
+  --delay <0...N> (seconds)                 delay between each test (default: 0)
+  -o, --output       output format printed to stdout (default: md)
+  -oe, --output-err  output format printed to stderr (default: none)
+  -v, --verbose                             verbose output
+  --progress                                print test progress indicators
+
+test parameters:
   -m, --model                     (default: models/7B/ggml-model-q4_0.gguf)
   -p, --n-prompt                         (default: 512)
   -n, --n-gen                            (default: 128)
   -pg                                (default: )
+  -d, --n-depth                          (default: 0)
   -b, --batch-size                       (default: 2048)
   -ub, --ubatch-size                     (default: 512)
   -ctk, --cache-type-k                   (default: f16)
   -ctv, --cache-type-v                   (default: f16)
-  -t, --threads                          (default: 8)
+  -dt, --defrag-thold                    (default: -1)
+  -t, --threads                          (default: system dependent)
   -C, --cpu-mask                   (default: 0x0)
   --cpu-strict <0|1>                        (default: 0)
   --poll <0...100>                          (default: 50)
   -ngl, --n-gpu-layers                   (default: 99)
-  -rpc, --rpc                  (default: )
+  -rpc, --rpc                  (default: none)
   -sm, --split-mode         (default: layer)
   -mg, --main-gpu                        (default: 0)
   -nkvo, --no-kv-offload <0|1>              (default: 0)
   -fa, --flash-attn <0|1>                   (default: 0)
   -mmp, --mmap <0|1>                        (default: 1)
-  --numa        (default: disabled)
   -embd, --embeddings <0|1>                 (default: 0)
   -ts, --tensor-split           (default: 0)
-  -r, --repetitions                      (default: 5)
-  --prio <0|1|2|3>                          (default: 0)
-  --delay <0...N> (seconds)                 (default: 0)
-  -o, --output       (default: md)
-  -oe, --output-err  (default: none)
-  -v, --verbose                             (default: 0)
+  -ot --override-tensors =;...
+                                            (default: disabled)
+  -nopo, --no-op-offload <0|1>              (default: 0)
 
-Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.
+Multiple values can be given for each parameter by separating them with ','
+or by specifying the parameter multiple times. Ranges can be given as
+'first-last' or 'first-last+step' or 'first-last*mult'.
 ```
 
 llama-bench can perform three types of tests:
@@ -66,12 +76,10 @@ With the exception of `-r`, `-o` and `-v`, all options can be specified multiple
 
 Each test is repeated the number of times given by `-r`, and the results are averaged. The results are given in average tokens per second (t/s) and standard deviation. Some output formats (e.g. json) also include the individual results of each repetition.
 
+Using the `-d ` option, each test can be run at a specified context depth, prefilling the KV cache with `` tokens.
+
 For a description of the other options, see the [main example](../main/README.md).
 
-Note:
-
-- When using SYCL backend, there would be hang issue in some cases. Please set `--mmp 0`.
-
 ## Examples
 
 ### Text generation with different models
@@ -148,6 +156,19 @@ $ ./llama-bench -ngl 10,20,30,31,32,33,34,35
 | llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  35 | pp 512     |   2400.01 ± 7.72 |
 | llama 7B mostly Q4_0           |   3.56 GiB |     6.74 B | CUDA       |  35 | tg 128     |    131.66 ± 0.49 |
 
+### Different prefilled context
+
+```
+$ ./llama-bench -d 0,512
+```
+
+| model                          |       size |     params | backend    | ngl |            test |                  t/s |
+| ------------------------------ | ---------: | ---------: | ---------- | --: | --------------: | -------------------: |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |           pp512 |      7340.20 ± 23.45 |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |           tg128 |        120.60 ± 0.59 |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |    pp512 @ d512 |      6425.91 ± 18.88 |
+| qwen2 7B Q4_K - Medium         |   4.36 GiB |     7.62 B | CUDA       |  99 |    tg128 @ d512 |        116.71 ± 0.60 |
+
 ## Output formats
 
 By default, llama-bench outputs the results in markdown format. The results can be output in other formats by using the `-o` option.
@@ -170,9 +191,9 @@ $ ./llama-bench -o csv
 ```
 
 ```csv
-build_commit,build_number,cuda,metal,gpu_blas,blas,cpu_info,gpu_info,model_filename,model_type,model_size,model_n_params,n_batch,n_threads,f16_kv,n_gpu_layers,main_gpu,mul_mat_q,tensor_split,n_prompt,n_gen,test_time,avg_ns,stddev_ns,avg_ts,stddev_ts
-"3469684","1275","1","0","0","1","1","13th Gen Intel(R) Core(TM) i9-13900K","NVIDIA GeForce RTX 3090 Ti","models/7B/ggml-model-q4_0.gguf","llama 7B mostly Q4_0","3825065984","6738415616","512","16","1","99","0","1","0.00","512","0","2023-09-23T12:09:01Z","212155977","732372","2413.341687","8.305961"
-"3469684","1275","1","0","0","1","1","13th Gen Intel(R) Core(TM) i9-13900K","NVIDIA GeForce RTX 3090 Ti","models/7B/ggml-model-q4_0.gguf","llama 7B mostly Q4_0","3825065984","6738415616","512","16","1","99","0","1","0.00","0","128","2023-09-23T12:09:02Z","969320879","2728399","132.052051","0.371342"
+build_commit,build_number,cpu_info,gpu_info,backends,model_filename,model_type,model_size,model_n_params,n_batch,n_ubatch,n_threads,cpu_mask,cpu_strict,poll,type_k,type_v,n_gpu_layers,split_mode,main_gpu,no_kv_offload,flash_attn,tensor_split,use_mmap,embeddings,n_prompt,n_gen,n_depth,test_time,avg_ns,stddev_ns,avg_ts,stddev_ts
+"8cf427ff","5163","AMD Ryzen 7 7800X3D 8-Core Processor","NVIDIA GeForce RTX 4080","CUDA","models/Qwen2.5-7B-Instruct-Q4_K_M.gguf","qwen2 7B Q4_K - Medium","4677120000","7615616512","2048","512","8","0x0","0","50","f16","f16","99","layer","0","0","0","0.00","1","0","512","0","0","2025-04-24T11:57:09Z","70285660","982040","7285.676949","100.064434"
+"8cf427ff","5163","AMD Ryzen 7 7800X3D 8-Core Processor","NVIDIA GeForce RTX 4080","CUDA","models/Qwen2.5-7B-Instruct-Q4_K_M.gguf","qwen2 7B Q4_K - Medium","4677120000","7615616512","2048","512","8","0x0","0","50","f16","f16","99","layer","0","0","0","0.00","1","0","0","128","0","2025-04-24T11:57:10Z","1067431600","3834831","119.915244","0.430617"
 ```
 
 ### JSON
@@ -184,64 +205,78 @@ $ ./llama-bench -o json
 ```json
 [
   {
-    "build_commit": "3469684",
-    "build_number": 1275,
-    "cuda": true,
-    "metal": false,
-    "gpu_blas": true,
-    "blas": true,
-    "cpu_info": "13th Gen Intel(R) Core(TM) i9-13900K",
-    "gpu_info": "NVIDIA GeForce RTX 3090 Ti",
-    "model_filename": "models/7B/ggml-model-q4_0.gguf",
-    "model_type": "llama 7B mostly Q4_0",
-    "model_size": 3825065984,
-    "model_n_params": 6738415616,
-    "n_batch": 512,
-    "n_threads": 16,
-    "f16_kv": true,
+    "build_commit": "8cf427ff",
+    "build_number": 5163,
+    "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor",
+    "gpu_info": "NVIDIA GeForce RTX 4080",
+    "backends": "CUDA",
+    "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf",
+    "model_type": "qwen2 7B Q4_K - Medium",
+    "model_size": 4677120000,
+    "model_n_params": 7615616512,
+    "n_batch": 2048,
+    "n_ubatch": 512,
+    "n_threads": 8,
+    "cpu_mask": "0x0",
+    "cpu_strict": false,
+    "poll": 50,
+    "type_k": "f16",
+    "type_v": "f16",
     "n_gpu_layers": 99,
+    "split_mode": "layer",
     "main_gpu": 0,
-    "mul_mat_q": true,
+    "no_kv_offload": false,
+    "flash_attn": false,
     "tensor_split": "0.00",
+    "use_mmap": true,
+    "embeddings": false,
     "n_prompt": 512,
     "n_gen": 0,
-    "test_time": "2023-09-23T12:09:57Z",
-    "avg_ns": 212365953,
-    "stddev_ns": 985423,
-    "avg_ts": 2410.974041,
-    "stddev_ts": 11.163766,
-    "samples_ns": [ 213837238, 211635853, 212328053, 211329715, 212698907 ],
-    "samples_ts": [ 2394.34, 2419.25, 2411.36, 2422.75, 2407.16 ]
+    "n_depth": 0,
+    "test_time": "2025-04-24T11:58:50Z",
+    "avg_ns": 72135640,
+    "stddev_ns": 1453752,
+    "avg_ts": 7100.002165,
+    "stddev_ts": 140.341520,
+    "samples_ns": [ 74601900, 71632900, 71745200, 71952700, 70745500 ],
+    "samples_ts": [ 6863.1, 7147.55, 7136.37, 7115.79, 7237.21 ]
   },
   {
-    "build_commit": "3469684",
-    "build_number": 1275,
-    "cuda": true,
-    "metal": false,
-    "gpu_blas": true,
-    "blas": true,
-    "cpu_info": "13th Gen Intel(R) Core(TM) i9-13900K",
-    "gpu_info": "NVIDIA GeForce RTX 3090 Ti",
-    "model_filename": "models/7B/ggml-model-q4_0.gguf",
-    "model_type": "llama 7B mostly Q4_0",
-    "model_size": 3825065984,
-    "model_n_params": 6738415616,
-    "n_batch": 512,
-    "n_threads": 16,
-    "f16_kv": true,
+    "build_commit": "8cf427ff",
+    "build_number": 5163,
+    "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor",
+    "gpu_info": "NVIDIA GeForce RTX 4080",
+    "backends": "CUDA",
+    "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf",
+    "model_type": "qwen2 7B Q4_K - Medium",
+    "model_size": 4677120000,
+    "model_n_params": 7615616512,
+    "n_batch": 2048,
+    "n_ubatch": 512,
+    "n_threads": 8,
+    "cpu_mask": "0x0",
+    "cpu_strict": false,
+    "poll": 50,
+    "type_k": "f16",
+    "type_v": "f16",
     "n_gpu_layers": 99,
+    "split_mode": "layer",
     "main_gpu": 0,
-    "mul_mat_q": true,
+    "no_kv_offload": false,
+    "flash_attn": false,
     "tensor_split": "0.00",
+    "use_mmap": true,
+    "embeddings": false,
     "n_prompt": 0,
     "n_gen": 128,
-    "test_time": "2023-09-23T12:09:59Z",
-    "avg_ns": 977425219,
-    "stddev_ns": 9268593,
-    "avg_ts": 130.965708,
-    "stddev_ts": 1.238924,
-    "samples_ns": [ 984472709, 974901233, 989474741, 970729355, 967548060 ],
-    "samples_ts": [ 130.019, 131.295, 129.362, 131.86, 132.293 ]
+    "n_depth": 0,
+    "test_time": "2025-04-24T11:58:51Z",
+    "avg_ns": 1076767880,
+    "stddev_ns": 9449585,
+    "avg_ts": 118.881588,
+    "stddev_ts": 1.041811,
+    "samples_ns": [ 1075361300, 1065089400, 1071761200, 1081934900, 1089692600 ],
+    "samples_ts": [ 119.03, 120.178, 119.43, 118.307, 117.464 ]
   }
 ]
 ```
@@ -254,8 +289,8 @@ $ ./llama-bench -o jsonl
 ```
 
 ```json lines
-{"build_commit":"3469684","build_number":1275,"cuda":true,"metal":false,"gpu_blas":true,"blas":true,"cpu_info":"13th Gen Intel(R) Core(TM) i9-13900K","gpu_info":"NVIDIA GeForce RTX 3090 Ti","model_filename":"models/7B/ggml-model-q4_0.gguf","model_type":"llama 7B mostly Q4_0","model_size":3825065984,"model_n_params":6738415616,"n_batch":512,"n_threads":16,"f16_kv":true,"n_gpu_layers":99,"main_gpu":0,"mul_mat_q":true,"tensor_split":"0.00","n_prompt":512,"n_gen":0,"test_time":"2023-09-23T12:09:57Z","avg_ns":212365953,"stddev_ns":985423,"avg_ts":2410.974041,"stddev_ts":11.163766,"samples_ns":[213837238,211635853,212328053,211329715,212698907],"samples_ts":[2394.34,2419.25,2411.36,2422.75,2407.16]}
-{"build_commit":"3469684","build_number":1275,"cuda":true,"metal":false,"gpu_blas":true,"blas":true,"cpu_info":"13th Gen Intel(R) Core(TM) i9-13900K","gpu_info":"NVIDIA GeForce RTX 3090 Ti","model_filename":"models/7B/ggml-model-q4_0.gguf","model_type":"llama 7B mostly Q4_0","model_size":3825065984,"model_n_params":6738415616,"n_batch":512,"n_threads":16,"f16_kv":true,"n_gpu_layers":99,"main_gpu":0,"mul_mat_q":true,"tensor_split":"0.00","n_prompt":0,"n_gen":128,"test_time":"2023-09-23T12:09:59Z","avg_ns":977425219,"stddev_ns":9268593,"avg_ts":130.965708,"stddev_ts":1.238924,"samples_ns":[984472709,974901233,989474741,970729355,967548060],"samples_ts":[130.019,131.295,129.362,131.86,132.293]}
+{"build_commit": "8cf427ff", "build_number": 5163, "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor", "gpu_info": "NVIDIA GeForce RTX 4080", "backends": "CUDA", "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf", "model_type": "qwen2 7B Q4_K - Medium", "model_size": 4677120000, "model_n_params": 7615616512, "n_batch": 2048, "n_ubatch": 512, "n_threads": 8, "cpu_mask": "0x0", "cpu_strict": false, "poll": 50, "type_k": "f16", "type_v": "f16", "n_gpu_layers": 99, "split_mode": "layer", "main_gpu": 0, "no_kv_offload": false, "flash_attn": false, "tensor_split": "0.00", "use_mmap": true, "embeddings": false, "n_prompt": 512, "n_gen": 0, "n_depth": 0, "test_time": "2025-04-24T11:59:33Z", "avg_ns": 70497220, "stddev_ns": 883196, "avg_ts": 7263.609157, "stddev_ts": 90.940578, "samples_ns": [ 71551000, 71222800, 70364100, 69439100, 69909100 ],"samples_ts": [ 7155.74, 7188.71, 7276.44, 7373.37, 7323.8 ]}
+{"build_commit": "8cf427ff", "build_number": 5163, "cpu_info": "AMD Ryzen 7 7800X3D 8-Core Processor", "gpu_info": "NVIDIA GeForce RTX 4080", "backends": "CUDA", "model_filename": "models/Qwen2.5-7B-Instruct-Q4_K_M.gguf", "model_type": "qwen2 7B Q4_K - Medium", "model_size": 4677120000, "model_n_params": 7615616512, "n_batch": 2048, "n_ubatch": 512, "n_threads": 8, "cpu_mask": "0x0", "cpu_strict": false, "poll": 50, "type_k": "f16", "type_v": "f16", "n_gpu_layers": 99, "split_mode": "layer", "main_gpu": 0, "no_kv_offload": false, "flash_attn": false, "tensor_split": "0.00", "use_mmap": true, "embeddings": false, "n_prompt": 0, "n_gen": 128, "n_depth": 0, "test_time": "2025-04-24T11:59:33Z", "avg_ns": 1068078400, "stddev_ns": 6279455, "avg_ts": 119.844681, "stddev_ts": 0.699739, "samples_ns": [ 1066331700, 1064864900, 1079042600, 1063328400, 1066824400 ],"samples_ts": [ 120.038, 120.203, 118.624, 120.377, 119.982 ]}
 ```
 
 
@@ -271,25 +306,32 @@ $ ./llama-bench -o sql
 CREATE TABLE IF NOT EXISTS test (
   build_commit TEXT,
   build_number INTEGER,
-  cuda INTEGER,
-  metal INTEGER,
-  gpu_blas INTEGER,
-  blas INTEGER,
   cpu_info TEXT,
   gpu_info TEXT,
+  backends TEXT,
   model_filename TEXT,
   model_type TEXT,
   model_size INTEGER,
   model_n_params INTEGER,
   n_batch INTEGER,
+  n_ubatch INTEGER,
   n_threads INTEGER,
-  f16_kv INTEGER,
+  cpu_mask TEXT,
+  cpu_strict INTEGER,
+  poll INTEGER,
+  type_k TEXT,
+  type_v TEXT,
   n_gpu_layers INTEGER,
+  split_mode TEXT,
   main_gpu INTEGER,
-  mul_mat_q INTEGER,
+  no_kv_offload INTEGER,
+  flash_attn INTEGER,
   tensor_split TEXT,
+  use_mmap INTEGER,
+  embeddings INTEGER,
   n_prompt INTEGER,
   n_gen INTEGER,
+  n_depth INTEGER,
   test_time TEXT,
   avg_ns INTEGER,
   stddev_ns INTEGER,
@@ -297,6 +339,6 @@ CREATE TABLE IF NOT EXISTS test (
   stddev_ts REAL
 );
 
-INSERT INTO test (build_commit, build_number, cuda, metal, gpu_blas, blas, cpu_info, gpu_info, model_filename, model_type, model_size, model_n_params, n_batch, n_threads, f16_kv, n_gpu_layers, main_gpu, mul_mat_q, tensor_split, n_prompt, n_gen, test_time, avg_ns, stddev_ns, avg_ts, stddev_ts) VALUES ('3469684', '1275', '1', '0', '0', '1', '1', '13th Gen Intel(R) Core(TM) i9-13900K', 'NVIDIA GeForce RTX 3090 Ti', 'models/7B/ggml-model-q4_0.gguf', 'llama 7B mostly Q4_0', '3825065984', '6738415616', '512', '16', '1', '99', '0', '1', '0.00', '512', '0', '2023-09-23T12:10:30Z', '212693772', '743623', '2407.240204', '8.409634');
-INSERT INTO test (build_commit, build_number, cuda, metal, gpu_blas, blas, cpu_info, gpu_info, model_filename, model_type, model_size, model_n_params, n_batch, n_threads, f16_kv, n_gpu_layers, main_gpu, mul_mat_q, tensor_split, n_prompt, n_gen, test_time, avg_ns, stddev_ns, avg_ts, stddev_ts) VALUES ('3469684', '1275', '1', '0', '0', '1', '1', '13th Gen Intel(R) Core(TM) i9-13900K', 'NVIDIA GeForce RTX 3090 Ti', 'models/7B/ggml-model-q4_0.gguf', 'llama 7B mostly Q4_0', '3825065984', '6738415616', '512', '16', '1', '99', '0', '1', '0.00', '0', '128', '2023-09-23T12:10:31Z', '977925003', '4037361', '130.891159', '0.537692');
+INSERT INTO test (build_commit, build_number, cpu_info, gpu_info, backends, model_filename, model_type, model_size, model_n_params, n_batch, n_ubatch, n_threads, cpu_mask, cpu_strict, poll, type_k, type_v, n_gpu_layers, split_mode, main_gpu, no_kv_offload, flash_attn, tensor_split, use_mmap, embeddings, n_prompt, n_gen, n_depth, test_time, avg_ns, stddev_ns, avg_ts, stddev_ts) VALUES ('8cf427ff', '5163', 'AMD Ryzen 7 7800X3D 8-Core Processor', 'NVIDIA GeForce RTX 4080', 'CUDA', 'models/Qwen2.5-7B-Instruct-Q4_K_M.gguf', 'qwen2 7B Q4_K - Medium', '4677120000', '7615616512', '2048', '512', '8', '0x0', '0', '50', 'f16', 'f16', '99', 'layer', '0', '0', '0', '0.00', '1', '0', '512', '0', '0', '2025-04-24T12:00:08Z', '69905000', '519516', '7324.546977', '54.032613');
+INSERT INTO test (build_commit, build_number, cpu_info, gpu_info, backends, model_filename, model_type, model_size, model_n_params, n_batch, n_ubatch, n_threads, cpu_mask, cpu_strict, poll, type_k, type_v, n_gpu_layers, split_mode, main_gpu, no_kv_offload, flash_attn, tensor_split, use_mmap, embeddings, n_prompt, n_gen, n_depth, test_time, avg_ns, stddev_ns, avg_ts, stddev_ts) VALUES ('8cf427ff', '5163', 'AMD Ryzen 7 7800X3D 8-Core Processor', 'NVIDIA GeForce RTX 4080', 'CUDA', 'models/Qwen2.5-7B-Instruct-Q4_K_M.gguf', 'qwen2 7B Q4_K - Medium', '4677120000', '7615616512', '2048', '512', '8', '0x0', '0', '50', 'f16', 'f16', '99', 'layer', '0', '0', '0', '0.00', '1', '0', '0', '128', '0', '2025-04-24T12:00:09Z', '1063608780', '4464130', '120.346696', '0.504647');
 ```
diff --git a/examples/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp
similarity index 64%
rename from examples/llama-bench/llama-bench.cpp
rename to tools/llama-bench/llama-bench.cpp
index cbcbfcee8..e59d61f19 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/tools/llama-bench/llama-bench.cpp
@@ -36,6 +36,46 @@ static uint64_t get_time_ns() {
     return std::chrono::nanoseconds(clock::now().time_since_epoch()).count();
 }
 
+static bool tensor_buft_override_equal(const llama_model_tensor_buft_override& a, const llama_model_tensor_buft_override& b) {
+    if (a.pattern != b.pattern) {
+        // cString comparison that may be null
+        if (a.pattern == nullptr || b.pattern == nullptr) {
+            return false;
+        }
+        if (strcmp(a.pattern, b.pattern) != 0) {
+            return false;
+        }
+    }
+    if (a.buft != b.buft) {
+        return false;
+    }
+    return true;
+}
+
+static bool vec_tensor_buft_override_equal(const std::vector& a, const std::vector& b) {
+    if (a.size() != b.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < a.size(); i++) {
+        if (!tensor_buft_override_equal(a[i], b[i])) {
+            return false;
+        }
+    }
+    return true;
+}
+
+static bool vec_vec_tensor_buft_override_equal(const std::vector>& a, const std::vector>& b) {
+    if (a.size() != b.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < a.size(); i++) {
+        if (!vec_tensor_buft_override_equal(a[i], b[i])) {
+            return false;
+        }
+    }
+    return true;
+}
+
 template  static std::string join(const std::vector & values, const std::string & delim) {
     std::ostringstream str;
     for (size_t i = 0; i < values.size(); i++) {
@@ -155,15 +195,57 @@ static std::string pair_str(const std::pair & p) {
     return buf;
 }
 
+static std::vector parse_int_range(const std::string & s) {
+    // first[-last[(+|*)step]]
+    std::regex range_regex(R"(^(\d+)(?:-(\d+)(?:([\+|\*])(\d+))?)?(?:,|$))");
+
+    std::smatch match;
+    std::string::const_iterator search_start(s.cbegin());
+    std::vector result;
+    while (std::regex_search(search_start, s.cend(), match, range_regex)) {
+        int  first = std::stoi(match[1]);
+        int  last  = match[2].matched ? std::stoi(match[2]) : first;
+        char op    = match[3].matched ? match[3].str()[0] : '+';
+        int  step  = match[4].matched ? std::stoi(match[4]) : 1;
+
+        for (int i = first; i <= last;) {
+            result.push_back(i);
+
+            int prev_i = i;
+
+            if (op == '+') {
+                i += step;
+            } else if (op == '*') {
+                i *= step;
+            } else {
+                throw std::invalid_argument("invalid range format");
+            }
+
+            if (i <= prev_i) {
+                throw std::invalid_argument("invalid range");
+            }
+        }
+        search_start = match.suffix().first;
+    }
+
+    if (search_start != s.cend()) {
+        throw std::invalid_argument("invalid range format");
+    }
+
+    return result;
+}
+
 struct cmd_params {
     std::vector         model;
     std::vector                 n_prompt;
     std::vector                 n_gen;
     std::vector> n_pg;
+    std::vector                 n_depth;
     std::vector                 n_batch;
     std::vector                 n_ubatch;
     std::vector           type_k;
     std::vector           type_v;
+    std::vector               defrag_thold;
     std::vector                 n_threads;
     std::vector         cpu_mask;
     std::vector                cpu_strict;
@@ -175,8 +257,10 @@ struct cmd_params {
     std::vector                no_kv_offload;
     std::vector                flash_attn;
     std::vector>  tensor_split;
+    std::vector> tensor_buft_overrides;
     std::vector                use_mmap;
     std::vector                embeddings;
+    std::vector                no_op_offload;
     ggml_numa_strategy               numa;
     int                              reps;
     ggml_sched_priority              prio;
@@ -192,10 +276,12 @@ static const cmd_params cmd_params_defaults = {
     /* n_prompt             */ { 512 },
     /* n_gen                */ { 128 },
     /* n_pg                 */ {},
+    /* n_depth              */ { 0 },
     /* n_batch              */ { 2048 },
     /* n_ubatch             */ { 512 },
     /* type_k               */ { GGML_TYPE_F16 },
     /* type_v               */ { GGML_TYPE_F16 },
+    /* defrag_thold         */ { -1.0f },
     /* n_threads            */ { cpu_get_num_math() },
     /* cpu_mask             */ { "0x0" },
     /* cpu_strict           */ { false },
@@ -207,8 +293,10 @@ static const cmd_params cmd_params_defaults = {
     /* no_kv_offload        */ { false },
     /* flash_attn           */ { false },
     /* tensor_split         */ { std::vector(llama_max_devices(), 0.0f) },
+    /* tensor_buft_overrides*/ { std::vector{ { nullptr, nullptr } } },
     /* use_mmap             */ { true },
     /* embeddings           */ { false },
+    /* no_op_offload        */ { false },
     /* numa                 */ GGML_NUMA_STRATEGY_DISABLED,
     /* reps                 */ 5,
     /* prio                 */ GGML_SCHED_PRIO_NORMAL,
@@ -224,12 +312,29 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("\n");
     printf("options:\n");
     printf("  -h, --help\n");
+    printf("  --numa        numa mode (default: disabled)\n");
+    printf("  -r, --repetitions                      number of times to repeat each test (default: %d)\n",
+           cmd_params_defaults.reps);
+    printf("  --prio <-1|0|1|2|3>                          process/thread priority (default: %d)\n",
+           cmd_params_defaults.prio);
+    printf("  --delay <0...N> (seconds)                 delay between each test (default: %d)\n",
+           cmd_params_defaults.delay);
+    printf("  -o, --output       output format printed to stdout (default: %s)\n",
+           output_format_str(cmd_params_defaults.output_format));
+    printf("  -oe, --output-err  output format printed to stderr (default: %s)\n",
+           output_format_str(cmd_params_defaults.output_format_stderr));
+    printf("  -v, --verbose                             verbose output\n");
+    printf("  --progress                                print test progress indicators\n");
+    printf("\n");
+    printf("test parameters:\n");
     printf("  -m, --model                     (default: %s)\n", join(cmd_params_defaults.model, ",").c_str());
     printf("  -p, --n-prompt                         (default: %s)\n",
            join(cmd_params_defaults.n_prompt, ",").c_str());
     printf("  -n, --n-gen                            (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
     printf("  -pg                                (default: %s)\n",
            join(transform_to_str(cmd_params_defaults.n_pg, pair_str), ",").c_str());
+    printf("  -d, --n-depth                          (default: %s)\n",
+           join(cmd_params_defaults.n_depth, ",").c_str());
     printf("  -b, --batch-size                       (default: %s)\n",
            join(cmd_params_defaults.n_batch, ",").c_str());
     printf("  -ub, --ubatch-size                     (default: %s)\n",
@@ -238,6 +343,8 @@ static void print_usage(int /* argc */, char ** argv) {
            join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
     printf("  -ctv, --cache-type-v                   (default: %s)\n",
            join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
+    printf("  -dt, --defrag-thold                    (default: %s)\n",
+           join(cmd_params_defaults.defrag_thold, ",").c_str());
     printf("  -t, --threads                          (default: %s)\n",
            join(cmd_params_defaults.n_threads, ",").c_str());
     printf("  -C, --cpu-mask                   (default: %s)\n",
@@ -261,23 +368,17 @@ static void print_usage(int /* argc */, char ** argv) {
            join(cmd_params_defaults.flash_attn, ",").c_str());
     printf("  -mmp, --mmap <0|1>                        (default: %s)\n",
            join(cmd_params_defaults.use_mmap, ",").c_str());
-    printf("  --numa        (default: disabled)\n");
     printf("  -embd, --embeddings <0|1>                 (default: %s)\n",
            join(cmd_params_defaults.embeddings, ",").c_str());
     printf("  -ts, --tensor-split           (default: 0)\n");
-    printf("  -r, --repetitions                      (default: %d)\n", cmd_params_defaults.reps);
-    printf("  --prio <0|1|2|3>                          (default: %d)\n", cmd_params_defaults.prio);
-    printf("  --delay <0...N> (seconds)                 (default: %d)\n", cmd_params_defaults.delay);
-    printf("  -o, --output       (default: %s)\n",
-           output_format_str(cmd_params_defaults.output_format));
-    printf("  -oe, --output-err  (default: %s)\n",
-           output_format_str(cmd_params_defaults.output_format_stderr));
-    printf("  -v, --verbose                             (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
-    printf("  --progress                                (default: %s)\n", cmd_params_defaults.progress ? "1" : "0");
+    printf("  -ot --override-tensors =;...\n");
+    printf("                                            (default: disabled)\n");
+    printf("  -nopo, --no-op-offload <0|1>              (default: 0)\n");
     printf("\n");
     printf(
-        "Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter "
-        "multiple times.\n");
+        "Multiple values can be given for each parameter by separating them with ','\n"
+        "or by specifying the parameter multiple times. Ranges can be given as\n"
+        "'first-last' or 'first-last+step' or 'first-last*mult'.\n");
 }
 
 static ggml_type ggml_type_from_name(const std::string & s) {
@@ -331,179 +432,197 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
 
-        if (arg == "-h" || arg == "--help") {
-            print_usage(argc, argv);
-            exit(0);
-        } else if (arg == "-m" || arg == "--model") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.model.insert(params.model.end(), p.begin(), p.end());
-        } else if (arg == "-p" || arg == "--n-prompt") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.n_prompt.insert(params.n_prompt.end(), p.begin(), p.end());
-        } else if (arg == "-n" || arg == "--n-gen") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.n_gen.insert(params.n_gen.end(), p.begin(), p.end());
-        } else if (arg == "-pg") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], ',');
-            if (p.size() != 2) {
-                invalid_param = true;
-                break;
-            }
-            params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
-        } else if (arg == "-b" || arg == "--batch-size") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
-        } else if (arg == "-ub" || arg == "--ubatch-size") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
-        } else if (arg == "-ctk" || arg == "--cache-type-k") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto                   p = string_split(argv[i], split_delim);
-            std::vector types;
-            for (const auto & t : p) {
-                ggml_type gt = ggml_type_from_name(t);
-                if (gt == GGML_TYPE_COUNT) {
+        try {
+            if (arg == "-h" || arg == "--help") {
+                print_usage(argc, argv);
+                exit(0);
+            } else if (arg == "-m" || arg == "--model") {
+                if (++i >= argc) {
                     invalid_param = true;
                     break;
                 }
-                types.push_back(gt);
-            }
-            if (invalid_param) {
-                break;
-            }
-            params.type_k.insert(params.type_k.end(), types.begin(), types.end());
-        } else if (arg == "-ctv" || arg == "--cache-type-v") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto                   p = string_split(argv[i], split_delim);
-            std::vector types;
-            for (const auto & t : p) {
-                ggml_type gt = ggml_type_from_name(t);
-                if (gt == GGML_TYPE_COUNT) {
+                auto p = string_split(argv[i], split_delim);
+                params.model.insert(params.model.end(), p.begin(), p.end());
+            } else if (arg == "-p" || arg == "--n-prompt") {
+                if (++i >= argc) {
                     invalid_param = true;
                     break;
                 }
-                types.push_back(gt);
-            }
-            if (invalid_param) {
-                break;
-            }
-            params.type_v.insert(params.type_v.end(), types.begin(), types.end());
-        } else if (arg == "-t" || arg == "--threads") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.n_threads.insert(params.n_threads.end(), p.begin(), p.end());
-        } else if (arg == "-C" || arg == "--cpu-mask") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.cpu_mask.insert(params.cpu_mask.end(), p.begin(), p.end());
-        } else if (arg == "--cpu-strict") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.cpu_strict.insert(params.cpu_strict.end(), p.begin(), p.end());
-        } else if (arg == "--poll") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.poll.insert(params.poll.end(), p.begin(), p.end());
-        } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
-        } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.rpc_servers.push_back(argv[i]);
-        } else if (arg == "-sm" || arg == "--split-mode") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto                          p = string_split(argv[i], split_delim);
-            std::vector modes;
-            for (const auto & m : p) {
-                llama_split_mode mode;
-                if (m == "none") {
-                    mode = LLAMA_SPLIT_MODE_NONE;
-                } else if (m == "layer") {
-                    mode = LLAMA_SPLIT_MODE_LAYER;
-                } else if (m == "row") {
-                    mode = LLAMA_SPLIT_MODE_ROW;
-                } else {
+                auto p = parse_int_range(argv[i]);
+                params.n_prompt.insert(params.n_prompt.end(), p.begin(), p.end());
+            } else if (arg == "-n" || arg == "--n-gen") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_gen.insert(params.n_gen.end(), p.begin(), p.end());
+            } else if (arg == "-pg") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], ',');
+                if (p.size() != 2) {
+                    invalid_param = true;
+                    break;
+                }
+                params.n_pg.push_back({ std::stoi(p[0]), std::stoi(p[1]) });
+            } else if (arg == "-d" || arg == "--n-depth") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_depth.insert(params.n_depth.end(), p.begin(), p.end());
+            } else if (arg == "-b" || arg == "--batch-size") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
+            } else if (arg == "-ub" || arg == "--ubatch-size") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
+            } else if (arg == "-ctk" || arg == "--cache-type-k") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+
+                std::vector types;
+                for (const auto & t : p) {
+                    ggml_type gt = ggml_type_from_name(t);
+                    if (gt == GGML_TYPE_COUNT) {
+                        invalid_param = true;
+                        break;
+                    }
+                    types.push_back(gt);
+                }
+                if (invalid_param) {
+                    break;
+                }
+                params.type_k.insert(params.type_k.end(), types.begin(), types.end());
+            } else if (arg == "-ctv" || arg == "--cache-type-v") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+
+                std::vector types;
+                for (const auto & t : p) {
+                    ggml_type gt = ggml_type_from_name(t);
+                    if (gt == GGML_TYPE_COUNT) {
+                        invalid_param = true;
+                        break;
+                    }
+                    types.push_back(gt);
+                }
+                if (invalid_param) {
+                    break;
+                }
+                params.type_v.insert(params.type_v.end(), types.begin(), types.end());
+            } else if (arg == "-dt" || arg == "--defrag-thold") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.defrag_thold.insert(params.defrag_thold.end(), p.begin(), p.end());
+            } else if (arg == "-t" || arg == "--threads") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_threads.insert(params.n_threads.end(), p.begin(), p.end());
+            } else if (arg == "-C" || arg == "--cpu-mask") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.cpu_mask.insert(params.cpu_mask.end(), p.begin(), p.end());
+            } else if (arg == "--cpu-strict") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.cpu_strict.insert(params.cpu_strict.end(), p.begin(), p.end());
+            } else if (arg == "--poll") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.poll.insert(params.poll.end(), p.begin(), p.end());
+            } else if (arg == "-ngl" || arg == "--n-gpu-layers") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = parse_int_range(argv[i]);
+                params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
+            } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.rpc_servers.push_back(argv[i]);
+            } else if (arg == "-sm" || arg == "--split-mode") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+
+                std::vector modes;
+                for (const auto & m : p) {
+                    llama_split_mode mode;
+                    if (m == "none") {
+                        mode = LLAMA_SPLIT_MODE_NONE;
+                    } else if (m == "layer") {
+                        mode = LLAMA_SPLIT_MODE_LAYER;
+                    } else if (m == "row") {
+                        mode = LLAMA_SPLIT_MODE_ROW;
+                    } else {
+                        invalid_param = true;
+                        break;
+                    }
+                    modes.push_back(mode);
+                }
+                if (invalid_param) {
+                    break;
+                }
+                params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end());
+            } else if (arg == "-mg" || arg == "--main-gpu") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.main_gpu = parse_int_range(argv[i]);
+            } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
+            } else if (arg == "--numa") {
+                if (++i >= argc) {
                     invalid_param = true;
                     break;
                 }
-                modes.push_back(mode);
-            }
-            if (invalid_param) {
-                break;
-            }
-            params.split_mode.insert(params.split_mode.end(), modes.begin(), modes.end());
-        } else if (arg == "-mg" || arg == "--main-gpu") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.main_gpu = string_split(argv[i], split_delim);
-        } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
-        } else if (arg == "--numa") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            } else {
                 std::string value(argv[i]);
-                /**/ if (value == "distribute" || value == "") {
+                if (value == "distribute" || value == "") {
                     params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE;
                 } else if (value == "isolate") {
                     params.numa = GGML_NUMA_STRATEGY_ISOLATE;
@@ -513,89 +632,183 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                     invalid_param = true;
                     break;
                 }
-            }
-        } else if (arg == "-fa" || arg == "--flash-attn") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
-        } else if (arg == "-mmp" || arg == "--mmap") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end());
-        } else if (arg == "-embd" || arg == "--embeddings") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            auto p = string_split(argv[i], split_delim);
-            params.embeddings.insert(params.embeddings.end(), p.begin(), p.end());
-        } else if (arg == "-ts" || arg == "--tensor-split") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            for (auto ts : string_split(argv[i], split_delim)) {
-                // split string by ; and /
-                const std::regex           regex{ R"([;/]+)" };
-                std::sregex_token_iterator it{ ts.begin(), ts.end(), regex, -1 };
-                std::vector   split_arg{ it, {} };
-                GGML_ASSERT(split_arg.size() <= llama_max_devices());
+            } else if (arg == "-fa" || arg == "--flash-attn") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
+            } else if (arg == "-mmp" || arg == "--mmap") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.use_mmap.insert(params.use_mmap.end(), p.begin(), p.end());
+            } else if (arg == "-embd" || arg == "--embeddings") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.embeddings.insert(params.embeddings.end(), p.begin(), p.end());
+            } else if (arg == "-nopo" || arg == "--no-op-offload") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto p = string_split(argv[i], split_delim);
+                params.no_op_offload.insert(params.no_op_offload.end(), p.begin(), p.end());
+            } else if (arg == "-ts" || arg == "--tensor-split") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                for (auto ts : string_split(argv[i], split_delim)) {
+                    // split string by ; and /
+                    const std::regex           regex{ R"([;/]+)" };
+                    std::sregex_token_iterator it{ ts.begin(), ts.end(), regex, -1 };
+                    std::vector   split_arg{ it, {} };
+                    GGML_ASSERT(split_arg.size() <= llama_max_devices());
 
-                std::vector tensor_split(llama_max_devices());
-                for (size_t i = 0; i < llama_max_devices(); ++i) {
-                    if (i < split_arg.size()) {
-                        tensor_split[i] = std::stof(split_arg[i]);
-                    } else {
-                        tensor_split[i] = 0.0f;
+                    std::vector tensor_split(llama_max_devices());
+                    for (size_t i = 0; i < llama_max_devices(); ++i) {
+                        if (i < split_arg.size()) {
+                            tensor_split[i] = std::stof(split_arg[i]);
+                        } else {
+                            tensor_split[i] = 0.0f;
+                        }
+                    }
+                    params.tensor_split.push_back(tensor_split);
+                }
+            } else if (arg == "-ot" || arg == "--override-tensor") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                auto * value = argv[i];
+                /* static */ std::map buft_list;
+                if (buft_list.empty()) {
+                    // enumerate all the devices and add their buffer types to the list
+                    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+                        auto * dev = ggml_backend_dev_get(i);
+                        auto * buft = ggml_backend_dev_buffer_type(dev);
+                        if (buft) {
+                            buft_list[ggml_backend_buft_name(buft)] = buft;
+                        }
                     }
                 }
-                params.tensor_split.push_back(tensor_split);
-            }
-        } else if (arg == "-r" || arg == "--repetitions") {
-            if (++i >= argc) {
+                auto override_group_span_len = std::strcspn(value, ",");
+                bool last_group = false;
+                do {
+                    if (override_group_span_len == 0) {
+                        // Adds an empty override-tensors for an empty span
+                        params.tensor_buft_overrides.push_back({{}});
+                        if (value[override_group_span_len] == '\0') {
+                            value = &value[override_group_span_len];
+                            last_group = true;
+                        } else {
+                            value = &value[override_group_span_len + 1];
+                            override_group_span_len = std::strcspn(value, ",");
+                        }
+                        continue;
+                    }
+                    // Stamps null terminators into the argv
+                    // value for this option to avoid the
+                    // memory leak present in the implementation
+                    // over in arg.cpp. Acceptable because we
+                    // only parse these args once in this program.
+                    auto * override_group = value;
+                    if (value[override_group_span_len] == '\0') {
+                        value = &value[override_group_span_len];
+                        last_group = true;
+                    } else {
+                        value[override_group_span_len] = '\0';
+                        value = &value[override_group_span_len + 1];
+                    }
+                    std::vector group_tensor_buft_overrides{};
+                    auto override_span_len = std::strcspn(override_group, ";");
+                    while (override_span_len > 0) {
+                        auto * override = override_group;
+                        if (override_group[override_span_len] != '\0') {
+                            override_group[override_span_len] = '\0';
+                            override_group = &override_group[override_span_len + 1];
+                        } else {
+                            override_group = &override_group[override_span_len];
+                        }
+                        auto tensor_name_span_len = std::strcspn(override, "=");
+                        if (tensor_name_span_len >= override_span_len) {
+                            invalid_param = true;
+                            break;
+                        }
+                        override[tensor_name_span_len] = '\0';
+                        auto * tensor_name = override;
+                        auto * buffer_type = &override[tensor_name_span_len + 1];
+                        if (buft_list.find(buffer_type) == buft_list.end()) {
+                            printf("error: unrecognized buffer type '%s'\n", buffer_type);
+                            printf("Available buffer types:\n");
+                            for (const auto & it : buft_list) {
+                                printf("  %s\n", ggml_backend_buft_name(it.second));
+                            }
+                            invalid_param = true;
+                            break;
+                        }
+                        group_tensor_buft_overrides.push_back({tensor_name, buft_list.at(buffer_type)});
+                        override_span_len = std::strcspn(override_group, ";");
+                    }
+                    if (invalid_param) {
+                        break;
+                    }
+                    group_tensor_buft_overrides.push_back({nullptr,nullptr});
+                    params.tensor_buft_overrides.push_back(group_tensor_buft_overrides);
+                    override_group_span_len = std::strcspn(value, ",");
+                } while (!last_group);
+            } else if (arg == "-r" || arg == "--repetitions") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.reps = std::stoi(argv[i]);
+            } else if (arg == "--prio") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.prio = (enum ggml_sched_priority) std::stoi(argv[i]);
+            } else if (arg == "--delay") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                params.delay = std::stoi(argv[i]);
+            } else if (arg == "-o" || arg == "--output") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                invalid_param = !output_format_from_str(argv[i], params.output_format);
+            } else if (arg == "-oe" || arg == "--output-err") {
+                if (++i >= argc) {
+                    invalid_param = true;
+                    break;
+                }
+                invalid_param = !output_format_from_str(argv[i], params.output_format_stderr);
+            } else if (arg == "-v" || arg == "--verbose") {
+                params.verbose = true;
+            } else if (arg == "--progress") {
+                params.progress = true;
+            } else {
                 invalid_param = true;
                 break;
             }
-            params.reps = std::stoi(argv[i]);
-        } else if (arg == "--prio") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.prio = (enum ggml_sched_priority) std::stoi(argv[i]);
-        } else if (arg == "--delay") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params.delay = std::stoi(argv[i]);
-        } else if (arg == "-o" || arg == "--output") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            invalid_param = !output_format_from_str(argv[i], params.output_format);
-        } else if (arg == "-oe" || arg == "--output-err") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            invalid_param = !output_format_from_str(argv[i], params.output_format_stderr);
-        } else if (arg == "-v" || arg == "--verbose") {
-            params.verbose = true;
-        } else if (arg == "--progress") {
-            params.progress = true;
-        } else {
+        } catch (const std::exception & e) {
+            fprintf(stderr, "error: %s\n", e.what());
             invalid_param = true;
             break;
         }
     }
+
     if (invalid_param) {
         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
         print_usage(argc, argv);
@@ -615,6 +828,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.n_pg.empty()) {
         params.n_pg = cmd_params_defaults.n_pg;
     }
+    if (params.n_depth.empty()) {
+        params.n_depth = cmd_params_defaults.n_depth;
+    }
     if (params.n_batch.empty()) {
         params.n_batch = cmd_params_defaults.n_batch;
     }
@@ -627,6 +843,9 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.type_v.empty()) {
         params.type_v = cmd_params_defaults.type_v;
     }
+    if (params.defrag_thold.empty()) {
+        params.defrag_thold = cmd_params_defaults.defrag_thold;
+    }
     if (params.n_gpu_layers.empty()) {
         params.n_gpu_layers = cmd_params_defaults.n_gpu_layers;
     }
@@ -648,12 +867,18 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.tensor_split.empty()) {
         params.tensor_split = cmd_params_defaults.tensor_split;
     }
+    if (params.tensor_buft_overrides.empty()) {
+        params.tensor_buft_overrides = cmd_params_defaults.tensor_buft_overrides;
+    }
     if (params.use_mmap.empty()) {
         params.use_mmap = cmd_params_defaults.use_mmap;
     }
     if (params.embeddings.empty()) {
         params.embeddings = cmd_params_defaults.embeddings;
     }
+    if (params.no_op_offload.empty()) {
+        params.no_op_offload = cmd_params_defaults.no_op_offload;
+    }
     if (params.n_threads.empty()) {
         params.n_threads = cmd_params_defaults.n_threads;
     }
@@ -674,10 +899,12 @@ struct cmd_params_instance {
     std::string        model;
     int                n_prompt;
     int                n_gen;
+    int                n_depth;
     int                n_batch;
     int                n_ubatch;
     ggml_type          type_k;
     ggml_type          type_v;
+    float              defrag_thold;
     int                n_threads;
     std::string        cpu_mask;
     bool               cpu_strict;
@@ -689,8 +916,10 @@ struct cmd_params_instance {
     bool               no_kv_offload;
     bool               flash_attn;
     std::vector tensor_split;
+    std::vector tensor_buft_overrides;
     bool               use_mmap;
     bool               embeddings;
+    bool               no_op_offload;
 
     llama_model_params to_llama_mparams() const {
         llama_model_params mparams = llama_model_default_params();
@@ -733,26 +962,36 @@ struct cmd_params_instance {
         mparams.tensor_split = tensor_split.data();
         mparams.use_mmap     = use_mmap;
 
+        if (tensor_buft_overrides.empty()) {
+            mparams.tensor_buft_overrides = nullptr;
+        } else {
+            GGML_ASSERT(tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
+            mparams.tensor_buft_overrides = tensor_buft_overrides.data();
+        }
+
         return mparams;
     }
 
     bool equal_mparams(const cmd_params_instance & other) const {
         return model == other.model && n_gpu_layers == other.n_gpu_layers && rpc_servers_str == other.rpc_servers_str &&
                split_mode == other.split_mode && main_gpu == other.main_gpu && use_mmap == other.use_mmap &&
-               tensor_split == other.tensor_split;
+               tensor_split == other.tensor_split && vec_tensor_buft_override_equal(tensor_buft_overrides, other.tensor_buft_overrides);
     }
 
     llama_context_params to_llama_cparams() const {
         llama_context_params cparams = llama_context_default_params();
 
-        cparams.n_ctx       = n_prompt + n_gen;
-        cparams.n_batch     = n_batch;
-        cparams.n_ubatch    = n_ubatch;
-        cparams.type_k      = type_k;
-        cparams.type_v      = type_v;
-        cparams.offload_kqv = !no_kv_offload;
-        cparams.flash_attn  = flash_attn;
-        cparams.embeddings  = embeddings;
+        cparams.n_ctx        = n_prompt + n_gen + n_depth;
+        cparams.n_batch      = n_batch;
+        cparams.n_ubatch     = n_ubatch;
+        cparams.type_k       = type_k;
+        cparams.type_v       = type_v;
+        cparams.defrag_thold = defrag_thold;
+        cparams.offload_kqv  = !no_kv_offload;
+        cparams.flash_attn   = flash_attn;
+        cparams.embeddings   = embeddings;
+        cparams.op_offload   = !no_op_offload;
+        cparams.swa_full     = false;
 
         return cparams;
     }
@@ -769,17 +1008,21 @@ static std::vector get_cmd_params_instances(const cmd_param
     for (const auto & sm : params.split_mode)
     for (const auto & mg : params.main_gpu)
     for (const auto & ts : params.tensor_split)
+    for (const auto & ot : params.tensor_buft_overrides)
     for (const auto & mmp : params.use_mmap)
     for (const auto & embd : params.embeddings)
+    for (const auto & nopo : params.no_op_offload)
     for (const auto & nb : params.n_batch)
     for (const auto & nub : params.n_ubatch)
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
+    for (const auto & defrag_thold : params.defrag_thold)
     for (const auto & nkvo : params.no_kv_offload)
     for (const auto & fa : params.flash_attn)
     for (const auto & nt : params.n_threads)
     for (const auto & cm : params.cpu_mask)
     for (const auto & cs : params.cpu_strict)
+    for (const auto & nd : params.n_depth)
     for (const auto & pl : params.poll) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {
@@ -789,10 +1032,12 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .model        = */ m,
                 /* .n_prompt     = */ n_prompt,
                 /* .n_gen        = */ 0,
+                /* .n_depth      = */ nd,
                 /* .n_batch      = */ nb,
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
+                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -804,8 +1049,10 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
+                /* .tensor_buft_overrides = */ ot,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
+                /* .no_op_offload= */ nopo,
             };
             instances.push_back(instance);
         }
@@ -818,10 +1065,12 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .model        = */ m,
                 /* .n_prompt     = */ 0,
                 /* .n_gen        = */ n_gen,
+                /* .n_depth      = */ nd,
                 /* .n_batch      = */ nb,
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
+                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -833,8 +1082,10 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
+                /* .tensor_buft_overrides = */ ot,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
+                /* .no_op_offload= */ nopo,
             };
             instances.push_back(instance);
         }
@@ -847,10 +1098,12 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .model        = */ m,
                 /* .n_prompt     = */ n_pg.first,
                 /* .n_gen        = */ n_pg.second,
+                /* .n_depth      = */ nd,
                 /* .n_batch      = */ nb,
                 /* .n_ubatch     = */ nub,
                 /* .type_k       = */ tk,
                 /* .type_v       = */ tv,
+                /* .defrag_thold = */ defrag_thold,
                 /* .n_threads    = */ nt,
                 /* .cpu_mask     = */ cm,
                 /* .cpu_strict   = */ cs,
@@ -862,8 +1115,10 @@ static std::vector get_cmd_params_instances(const cmd_param
                 /* .no_kv_offload= */ nkvo,
                 /* .flash_attn   = */ fa,
                 /* .tensor_split = */ ts,
+                /* .tensor_buft_overrides = */ ot,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
+                /* .no_op_offload= */ nopo,
             };
             instances.push_back(instance);
         }
@@ -890,16 +1145,20 @@ struct test {
     int                      poll;
     ggml_type                type_k;
     ggml_type                type_v;
+    float                    defrag_thold;
     int                      n_gpu_layers;
     llama_split_mode         split_mode;
     int                      main_gpu;
     bool                     no_kv_offload;
     bool                     flash_attn;
     std::vector       tensor_split;
+    std::vector tensor_buft_overrides;
     bool                     use_mmap;
     bool                     embeddings;
+    bool                     no_op_offload;
     int                      n_prompt;
     int                      n_gen;
+    int                      n_depth;
     std::string              test_time;
     std::vector    samples_ns;
 
@@ -921,16 +1180,20 @@ struct test {
         poll           = inst.poll;
         type_k         = inst.type_k;
         type_v         = inst.type_v;
+        defrag_thold   = inst.defrag_thold;
         n_gpu_layers   = inst.n_gpu_layers;
         split_mode     = inst.split_mode;
         main_gpu       = inst.main_gpu;
         no_kv_offload  = inst.no_kv_offload;
         flash_attn     = inst.flash_attn;
         tensor_split   = inst.tensor_split;
+        tensor_buft_overrides = inst.tensor_buft_overrides;
         use_mmap       = inst.use_mmap;
         embeddings     = inst.embeddings;
+        no_op_offload  = inst.no_op_offload;
         n_prompt       = inst.n_prompt;
         n_gen          = inst.n_gen;
+        n_depth        = inst.n_depth;
         // RFC 3339 date-time format
         time_t t       = time(NULL);
         std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
@@ -972,9 +1235,10 @@ struct test {
             "build_commit", "build_number", "cpu_info",       "gpu_info",   "backends",     "model_filename",
             "model_type",   "model_size",   "model_n_params", "n_batch",    "n_ubatch",     "n_threads",
             "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
-            "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "use_mmap",
-            "embeddings",   "n_prompt",     "n_gen",          "test_time",  "avg_ns",       "stddev_ns",
-            "avg_ts",       "stddev_ts",
+            "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
+            "defrag_thold",
+            "use_mmap",     "embeddings",   "no_op_offload",   "n_prompt",       "n_gen",      "n_depth",      "test_time",
+            "avg_ns",       "stddev_ns",    "avg_ts",         "stddev_ts",
         };
         return fields;
     }
@@ -984,15 +1248,15 @@ struct test {
     static field_type get_field_type(const std::string & field) {
         if (field == "build_number" || field == "n_batch" || field == "n_ubatch" || field == "n_threads" ||
             field == "poll" || field == "model_size" || field == "model_n_params" || field == "n_gpu_layers" ||
-            field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "avg_ns" ||
-            field == "stddev_ns") {
+            field == "main_gpu" || field == "n_prompt" || field == "n_gen" || field == "n_depth" ||
+            field == "avg_ns" || field == "stddev_ns" || field == "no_op_offload") {
             return INT;
         }
         if (field == "f16_kv" || field == "no_kv_offload" || field == "cpu_strict" || field == "flash_attn" ||
             field == "use_mmap" || field == "embeddings") {
             return BOOL;
         }
-        if (field == "avg_ts" || field == "stddev_ts") {
+        if (field == "avg_ts" || field == "stddev_ts" || field == "defrag_thold") {
             return FLOAT;
         }
         return STRING;
@@ -1000,6 +1264,7 @@ struct test {
 
     std::vector get_values() const {
         std::string tensor_split_str;
+        std::string tensor_buft_overrides_str;
         int         max_nonzero = 0;
         for (size_t i = 0; i < llama_max_devices(); i++) {
             if (tensor_split[i] > 0) {
@@ -1014,6 +1279,26 @@ struct test {
                 tensor_split_str += "/";
             }
         }
+        if (tensor_buft_overrides.size() == 1) {
+            // Last element of tensor_buft_overrides is always a null pattern
+            // so if it is only one element long, it must be a null pattern.
+            GGML_ASSERT(tensor_buft_overrides[0].pattern == nullptr);
+            tensor_buft_overrides_str += "none";
+        } else {
+            for (size_t i = 0; i < tensor_buft_overrides.size()-1; i++) {
+                // Last element of tensor_buft_overrides is always a null pattern
+                if (tensor_buft_overrides[i].pattern == nullptr) {
+                    tensor_buft_overrides_str += "none";
+                } else {
+                    tensor_buft_overrides_str += tensor_buft_overrides[i].pattern;
+                    tensor_buft_overrides_str += "=";
+                    tensor_buft_overrides_str += ggml_backend_buft_name(tensor_buft_overrides[i].buft);
+                }
+                if (i + 2 < tensor_buft_overrides.size()) {
+                    tensor_buft_overrides_str += ";";
+                }
+            }
+        }
         std::vector values = { build_commit,
                                             std::to_string(build_number),
                                             cpu_info,
@@ -1037,10 +1322,14 @@ struct test {
                                             std::to_string(no_kv_offload),
                                             std::to_string(flash_attn),
                                             tensor_split_str,
+                                            tensor_buft_overrides_str,
+                                            std::to_string(defrag_thold),
                                             std::to_string(use_mmap),
                                             std::to_string(embeddings),
+                                            std::to_string(no_op_offload),
                                             std::to_string(n_prompt),
                                             std::to_string(n_gen),
+                                            std::to_string(n_depth),
                                             test_time,
                                             std::to_string(avg_ns()),
                                             std::to_string(stdev_ns()),
@@ -1218,7 +1507,10 @@ struct markdown_printer : public printer {
             return 4;
         }
         if (field == "test") {
-            return 13;
+            return 15;
+        }
+        if (field == "no_op_offload") {
+            return 4;
         }
 
         int width = std::max((int) field.length(), 10);
@@ -1251,9 +1543,15 @@ struct markdown_printer : public printer {
         if (field == "embeddings") {
             return "embd";
         }
+        if (field == "no_op_offload") {
+            return "nopo";
+        }
         if (field == "tensor_split") {
             return "ts";
         }
+        if (field == "tensor_buft_overrides") {
+            return "ot";
+        }
         return field;
     }
 
@@ -1292,6 +1590,9 @@ struct markdown_printer : public printer {
         if (params.type_v.size() > 1 || params.type_v != cmd_params_defaults.type_v) {
             fields.emplace_back("type_v");
         }
+        if (params.defrag_thold.size() > 1 || params.defrag_thold != cmd_params_defaults.defrag_thold) {
+            fields.emplace_back("defrag_thold");
+        }
         if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) {
             fields.emplace_back("main_gpu");
         }
@@ -1307,12 +1608,18 @@ struct markdown_printer : public printer {
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }
+        if (params.tensor_buft_overrides.size() > 1 || !vec_vec_tensor_buft_override_equal(params.tensor_buft_overrides, cmd_params_defaults.tensor_buft_overrides)) {
+            fields.emplace_back("tensor_buft_overrides");
+        }
         if (params.use_mmap.size() > 1 || params.use_mmap != cmd_params_defaults.use_mmap) {
             fields.emplace_back("use_mmap");
         }
         if (params.embeddings.size() > 1 || params.embeddings != cmd_params_defaults.embeddings) {
             fields.emplace_back("embeddings");
         }
+        if (params.no_op_offload.size() > 1 || params.no_op_offload != cmd_params_defaults.no_op_offload) {
+            fields.emplace_back("no_op_offload");
+        }
         fields.emplace_back("test");
         fields.emplace_back("t/s");
 
@@ -1362,6 +1669,10 @@ struct markdown_printer : public printer {
                 } else {
                     snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
                 }
+                if (t.n_depth > 0) {
+                    int len = strlen(buf);
+                    snprintf(buf + len, sizeof(buf) - len, " @ d%d", t.n_depth);
+                }
                 value = buf;
             } else if (field == "t/s") {
                 snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
@@ -1427,7 +1738,7 @@ struct sql_printer : public printer {
     }
 };
 
-static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
+static bool test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model   = llama_get_model(ctx);
@@ -1444,14 +1755,19 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        int res = llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode prompt batch, res = %d\n", __func__, res);
+            return false;
+        }
         n_processed += n_tokens;
     }
 
     llama_synchronize(ctx);
+    return true;
 }
 
-static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
+static bool test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);
 
     const llama_model * model   = llama_get_model(ctx);
@@ -1461,10 +1777,15 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
 
     for (int i = 0; i < n_gen; i++) {
-        llama_decode(ctx, llama_batch_get_one(&token, 1));
+        int res = llama_decode(ctx, llama_batch_get_one(&token, 1));
+        if (res != 0) {
+            fprintf(stderr, "%s: failed to decode generation batch, res = %d\n", __func__, res);
+            return false;
+        }
         llama_synchronize(ctx);
         token = std::rand() % n_vocab;
     }
+    return true;
 }
 
 static void llama_null_log_callback(enum ggml_log_level level, const char * text, void * user_data) {
@@ -1507,10 +1828,11 @@ int main(int argc, char ** argv) {
     fprintf(stderr, "warning: sanitizer enabled, performance may be affected\n");
 #endif
 
-    cmd_params params = parse_cmd_params(argc, argv);
-
     // initialize backends
     ggml_backend_load_all();
+
+    cmd_params params = parse_cmd_params(argc, argv);
+
     auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
     if (!cpu_dev) {
         fprintf(stderr, "%s: error: CPU backend is not loaded\n", __func__);
@@ -1578,7 +1900,7 @@ int main(int argc, char ** argv) {
 
         test t(inst, lmodel, ctx);
 
-        llama_kv_self_clear(ctx);
+        llama_memory_clear(llama_get_memory(ctx), false);
 
         // cool off before the test
         if (params.delay) {
@@ -1608,17 +1930,37 @@ int main(int argc, char ** argv) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count);
             }
             //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-            test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run prompt warmup\n", __func__);
+                exit(1);
+            }
         }
         if (t.n_gen > 0) {
             if (params.progress) {
                 fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count);
             }
-            test_gen(ctx, 1, t.n_threads);
+            bool res = test_gen(ctx, 1, t.n_threads);
+            if (!res) {
+                fprintf(stderr, "%s: error: failed to run gen warmup\n", __func__);
+                exit(1);
+            }
         }
 
         for (int i = 0; i < params.reps; i++) {
-            llama_kv_self_clear(ctx);
+            llama_memory_clear(llama_get_memory(ctx), false);
+
+            if (t.n_depth > 0) {
+                if (params.progress) {
+                    fprintf(stderr, "llama-bench: benchmark %d/%zu: depth run %d/%d\n", params_idx, params_count,
+                            i + 1, params.reps);
+                }
+                bool res = test_prompt(ctx, t.n_depth, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run depth\n", __func__);
+                    exit(1);
+                }
+            }
 
             uint64_t t_start = get_time_ns();
 
@@ -1627,14 +1969,22 @@ int main(int argc, char ** argv) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                bool res = test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run prompt\n", __func__);
+                    exit(1);
+                }
             }
             if (t.n_gen > 0) {
                 if (params.progress) {
                     fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count,
                             i + 1, params.reps);
                 }
-                test_gen(ctx, t.n_gen, t.n_threads);
+                bool res = test_gen(ctx, t.n_gen, t.n_threads);
+                if (!res) {
+                    fprintf(stderr, "%s: error: failed to run gen\n", __func__);
+                    exit(1);
+                }
             }
 
             uint64_t t_ns = get_time_ns() - t_start;
diff --git a/examples/main/CMakeLists.txt b/tools/main/CMakeLists.txt
similarity index 100%
rename from examples/main/CMakeLists.txt
rename to tools/main/CMakeLists.txt
diff --git a/examples/main/README.md b/tools/main/README.md
similarity index 99%
rename from examples/main/README.md
rename to tools/main/README.md
index e4b3590b5..4f16ad6b2 100644
--- a/examples/main/README.md
+++ b/tools/main/README.md
@@ -1,4 +1,4 @@
-# llama.cpp/examples/main
+# llama.cpp/tools/main
 
 This example program allows you to use various LLaMA language models easily and efficiently. It is specifically designed to work with the [llama.cpp](https://github.com/ggml-org/llama.cpp) project, which provides a plain C/C++ implementation with optional 4-bit quantization support for faster, lower memory inference, and is optimized for desktop CPUs. This program can be used to perform various inference tasks with LLaMA models, including generating text based on user-provided prompts and chat-like interactions with reverse prompts.
 
diff --git a/examples/main/main.cpp b/tools/main/main.cpp
similarity index 96%
rename from examples/main/main.cpp
rename to tools/main/main.cpp
index fd7410a64..19b247b0d 100644
--- a/examples/main/main.cpp
+++ b/tools/main/main.cpp
@@ -99,14 +99,6 @@ int main(int argc, char ** argv) {
     console::init(params.simple_io, params.use_color);
     atexit([]() { console::cleanup(); });
 
-    if (params.logits_all) {
-        LOG_ERR("************\n");
-        LOG_ERR("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
-        LOG_ERR("************\n\n");
-
-        return 0;
-    }
-
     if (params.embedding) {
         LOG_ERR("************\n");
         LOG_ERR("%s: please use the 'embedding' tool for embedding calculations\n", __func__);
@@ -155,12 +147,19 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto * mem = llama_get_memory(ctx);
+
     const llama_vocab * vocab = llama_model_get_vocab(model);
     auto chat_templates = common_chat_templates_init(model, params.chat_template);
 
     LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
 
-    auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
+    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    if (!cpu_dev) {
+        LOG_ERR("%s: no CPU backend found\n", __func__);
+        return 1;
+    }
+    auto * reg = ggml_backend_dev_backend_reg(cpu_dev);
     auto * ggml_threadpool_new_fn = (decltype(ggml_threadpool_new) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_new");
     auto * ggml_threadpool_free_fn = (decltype(ggml_threadpool_free) *) ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free");
 
@@ -354,7 +353,7 @@ int main(int argc, char ** argv) {
         }
 
         // remove any "future" tokens that we might have inherited from the previous session
-        llama_kv_self_seq_rm(ctx, -1, n_matching_session_tokens, -1);
+        llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1);
     }
 
     LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
@@ -602,8 +601,8 @@ int main(int argc, char ** argv) {
                     LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                             n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_self_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_self_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_memory_seq_rm (mem, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_memory_seq_add(mem, 0, params.n_keep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
 
@@ -626,9 +625,9 @@ int main(int argc, char ** argv) {
                     LOG_DBG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
                     LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-                    llama_kv_self_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
-                    llama_kv_self_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
-                    llama_kv_self_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
+                    llama_memory_seq_add(mem, 0, ga_i,                n_past,              ib*bd);
+                    llama_memory_seq_div(mem, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                    llama_memory_seq_add(mem, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
 
                     n_past -= bd;
 
@@ -865,9 +864,22 @@ int main(int argc, char ** argv) {
                 console::set_display(console::reset);
                 display = true;
 
-                // Add tokens to embd only if the input buffer is non-empty
-                // Entering a empty line lets the user pass control back
-                if (buffer.length() > 1) {
+                if (buffer.empty()) { // Ctrl+D on empty line exits
+                    LOG("EOF by user\n");
+                    break;
+                }
+
+                if (buffer.back() == '\n') {
+                    // Implement #587:
+                    // If the user wants the text to end in a newline,
+                    // this should be accomplished by explicitly adding a newline by using \ followed by return,
+                    // then returning control by pressing return again.
+                    buffer.pop_back();
+                }
+
+                if (buffer.empty()) { // Enter key on empty line lets the user pass control back
+                    LOG_DBG("empty line, passing control back\n");
+                } else { // Add tokens to embd only if the input buffer is non-empty
                     // append input suffix if any
                     if (!params.input_suffix.empty() && !params.conversation_mode) {
                         LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
@@ -915,8 +927,6 @@ int main(int argc, char ** argv) {
 
                     n_remain -= line_inp.size();
                     LOG_DBG("n_remain: %d\n", n_remain);
-                } else {
-                    LOG_DBG("empty line, passing control back\n");
                 }
 
                 input_echo = false; // do not echo this again
diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt
new file mode 100644
index 000000000..4baa15b96
--- /dev/null
+++ b/tools/mtmd/CMakeLists.txt
@@ -0,0 +1,60 @@
+# mtmd
+
+find_package(Threads REQUIRED)
+
+add_library(mtmd
+            mtmd.cpp
+            mtmd-audio.cpp
+            mtmd.h
+            clip.cpp
+            clip.h
+            clip-impl.h
+            mtmd-helper.cpp
+            mtmd-helper.h
+            )
+
+target_link_libraries     (mtmd PUBLIC ggml llama)
+target_link_libraries     (mtmd PRIVATE Threads::Threads)
+target_include_directories(mtmd PUBLIC  .)
+target_include_directories(mtmd PRIVATE ../..)
+target_include_directories(mtmd PRIVATE ../../vendor)
+target_compile_features   (mtmd PRIVATE cxx_std_17)
+
+if (BUILD_SHARED_LIBS)
+    set_target_properties     (mtmd PROPERTIES POSITION_INDEPENDENT_CODE ON)
+    target_compile_definitions(mtmd PRIVATE LLAMA_BUILD)
+    target_compile_definitions(mtmd PUBLIC  LLAMA_SHARED)
+endif()
+
+set(MTMD_PUBLIC_HEADERS
+    ${CMAKE_CURRENT_SOURCE_DIR}/mtmd.h
+    ${CMAKE_CURRENT_SOURCE_DIR}/mtmd-helper.h
+    )
+
+set_target_properties(mtmd
+    PROPERTIES
+    PUBLIC_HEADER "${MTMD_PUBLIC_HEADERS}")
+
+install(TARGETS mtmd LIBRARY PUBLIC_HEADER)
+
+if (NOT MSVC)
+    # for stb_image.h and miniaudio.h
+    target_compile_options(mtmd PRIVATE -Wno-cast-qual)
+endif()
+
+if (TARGET BUILD_INFO)
+    add_dependencies(mtmd        BUILD_INFO)
+    add_dependencies(mtmd-helper BUILD_INFO)
+endif()
+
+add_executable(llama-llava-cli    deprecation-warning.cpp)
+add_executable(llama-gemma3-cli   deprecation-warning.cpp)
+add_executable(llama-minicpmv-cli deprecation-warning.cpp)
+add_executable(llama-qwen2vl-cli  deprecation-warning.cpp)
+
+set(TARGET llama-mtmd-cli)
+add_executable         (${TARGET} mtmd-cli.cpp)
+set_target_properties  (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
+install                (TARGETS ${TARGET} RUNTIME)
+target_link_libraries  (${TARGET} PRIVATE common mtmd Threads::Threads)
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md
new file mode 100644
index 000000000..ef31d1957
--- /dev/null
+++ b/tools/mtmd/README.md
@@ -0,0 +1,63 @@
+# Multimodal Support in llama.cpp
+
+This directory provides multimodal capabilities for `llama.cpp`. Initially intended as a showcase for running LLaVA models, its scope has expanded significantly over time to include various other vision-capable models. As a result, LLaVA is no longer the only multimodal architecture supported.
+
+> [!IMPORTANT]
+>
+> Multimodal support can be viewed as a sub-project within `llama.cpp`. It is under **very heavy development**, and **breaking changes are expected**.
+
+The naming and structure related to multimodal support have evolved, which might cause some confusion. Here's a brief timeline to clarify:
+
+- [#3436](https://github.com/ggml-org/llama.cpp/pull/3436): Initial support for LLaVA 1.5 was added, introducing `llava.cpp` and `clip.cpp`. The `llava-cli` binary was created for model interaction.
+- [#4954](https://github.com/ggml-org/llama.cpp/pull/4954): Support for MobileVLM was added, becoming the second vision model supported. This built upon the existing `llava.cpp`, `clip.cpp`, and `llava-cli` infrastructure.
+- **Expansion & Fragmentation:** Many new models were subsequently added (e.g., [#7599](https://github.com/ggml-org/llama.cpp/pull/7599), [#10361](https://github.com/ggml-org/llama.cpp/pull/10361), [#12344](https://github.com/ggml-org/llama.cpp/pull/12344), and others). However, `llava-cli` lacked support for the increasingly complex chat templates required by these models. This led to the creation of model-specific binaries like `qwen2vl-cli`, `minicpmv-cli`, and `gemma3-cli`. While functional, this proliferation of command-line tools became confusing for users.
+- [#12849](https://github.com/ggml-org/llama.cpp/pull/12849): `libmtmd` was introduced as a replacement for `llava.cpp`. Its goals include providing a single, unified command-line interface, improving the user/developer experience (UX/DX), and supporting both audio and image inputs.
+- [#13012](https://github.com/ggml-org/llama.cpp/pull/13012): `mtmd-cli` was added, consolidating the various model-specific CLIs into a single tool powered by `libmtmd`.
+
+## Pre-quantized models
+
+See the list of pre-quantized model [here](../../docs/multimodal.md)
+
+## How it works and what is `mmproj`?
+
+Multimodal support in `llama.cpp` works by encoding images into embeddings using a separate model component, and then feeding these embeddings into the language model.
+
+This approach keeps the multimodal components distinct from the core `libllama` library. Separating these allows for faster, independent development cycles. While many modern vision models are based on Vision Transformers (ViTs), their specific pre-processing and projection steps can vary significantly. Integrating this diverse complexity directly into `libllama` is currently challenging.
+
+Consequently, running a multimodal model typically requires two GGUF files:
+1.  The standard language model file.
+2.  A corresponding **multimodal projector (`mmproj`)** file, which handles the image encoding and projection.
+
+## What is `libmtmd`?
+
+As outlined in the history, `libmtmd` is the modern library designed to replace the original `llava.cpp` implementation for handling multimodal inputs.
+
+Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advantages:
+- **Unified Interface:** Aims to consolidate interaction for various multimodal models.
+- **Improved UX/DX:** Features a more intuitive API, inspired by the `Processor` class in the Hugging Face `transformers` library.
+- **Flexibility:** Designed to support multiple input types (text, audio, images) while respecting the wide variety of chat templates used by different models.
+
+## How to obtain `mmproj`
+
+Multimodal projector (`mmproj`) files are specific to each model architecture.
+
+For the following models, you can use `convert_hf_to_gguf.py` with `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) ; See the guide [here](../../docs/multimodal/gemma3.md) - Note: 1B variant does not have vision support
+- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: we don't support conversion of `InternVL3-*-hf` model, only non-HF version is supported ; `InternLM2Model` **text** model is not supported)
+
+For older models, please refer to the relevant guide for instructions on how to obtain or create them:
+
+NOTE: conversion scripts are located under `tools/mtmd/legacy-models`
+
+- [LLaVA](../../docs/multimodal/llava.md)
+- [MobileVLM](../../docs/multimodal/MobileVLM.md)
+- [GLM-Edge](../../docs/multimodal/glmedge.md)
+- [MiniCPM-V 2.5](../../docs/multimodal/minicpmv2.5.md)
+- [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md)
+- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
+- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
diff --git a/examples/llava/clip-impl.h b/tools/mtmd/clip-impl.h
similarity index 65%
rename from examples/llava/clip-impl.h
rename to tools/mtmd/clip-impl.h
index 4d7340a56..62c936ed0 100644
--- a/examples/llava/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -2,10 +2,9 @@
 #include "gguf.h"
 #include "clip.h"
 
-#include "clip.h"
-
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -17,40 +16,44 @@
 #define KEY_FTYPE               "general.file_type"
 #define KEY_NAME                "general.name"
 #define KEY_DESCRIPTION         "general.description"
-#define KEY_HAS_TEXT_ENC        "clip.has_text_encoder"
-#define KEY_HAS_VIS_ENC         "clip.has_vision_encoder"
-#define KEY_HAS_LLAVA_PROJ      "clip.has_llava_projector"
-#define KEY_HAS_MINICPMV_PROJ   "clip.has_minicpmv_projector"
-#define KEY_HAS_GLM_PROJ        "clip.has_glm_projector"
-#define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
-#define KEY_HAS_QWEN2VL_MERGER  "clip.has_qwen2vl_merger"
+#define KEY_PROJ_TYPE           "clip.projector_type"
+#define KEY_HAS_AUDIO_ENC       "clip.has_audio_encoder"
+#define KEY_HAS_VISION_ENC      "clip.has_vision_encoder"
 #define KEY_USE_GELU            "clip.use_gelu"
 #define KEY_USE_SILU            "clip.use_silu"
+
 #define KEY_N_EMBD              "clip.%s.embedding_length"
 #define KEY_N_FF                "clip.%s.feed_forward_length"
 #define KEY_N_BLOCK             "clip.%s.block_count"
+#define KEY_PROJ_DIM            "clip.%s.projection_dim"
 #define KEY_N_HEAD              "clip.%s.attention.head_count"
 #define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
-#define KEY_PROJ_DIM            "clip.%s.projection_dim"
-#define KEY_TOKENS              "tokenizer.ggml.tokens"
-#define KEY_N_POSITIONS         "clip.text.context_length"
+
+// vision-specific
 #define KEY_IMAGE_SIZE          "clip.vision.image_size"
 #define KEY_PATCH_SIZE          "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN          "clip.vision.image_mean"
 #define KEY_IMAGE_STD           "clip.vision.image_std"
-#define KEY_PROJ_TYPE           "clip.projector_type"
 #define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
+#define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
+#define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 
 #define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
+#define KEY_WIN_ATTN_PATTERN      "clip.vision.n_wa_pattern"
+#define KEY_ATTN_WINDOW_SIZE      "clip.vision.window_size"
+#define KEY_MINICPMV_VERSION      "clip.minicpmv_version"
+
+// audio-specific
+#define KEY_A_NUM_MEL_BINS      "clip.audio.num_mel_bins"
+#define KEY_A_PROJ_STACK_FACTOR "clip.audio.projector.stack_factor"
 
 
 //
 // tensor name constants
 //
 
-#define TN_TOKEN_EMBD      "%s.token_embd.weight"
 #define TN_POS_EMBD        "%s.position_embd.weight"
 #define TN_CLASS_EMBD      "v.class_embd"
 #define TN_PATCH_EMBD      "v.patch_embd.weight"  // not rename tensor with ".0" postfix for backwrad compat
@@ -60,21 +63,31 @@
 #define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
 #define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
 #define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
+#define TN_ATTN_K_NORM     "%s.blk.%d.attn_k_norm.%s"
+#define TN_ATTN_Q_NORM     "%s.blk.%d.attn_q_norm.%s"
 #define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
+#define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
-#define TN_LN_1            "%s.blk.%d.ln1.%s"
-#define TN_LN_2            "%s.blk.%d.ln2.%s"
+#define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
+#define TN_LN_1            "%s.blk.%d.ln1.%s" // layer norm
+#define TN_LN_2            "%s.blk.%d.ln2.%s" // layer norm
+#define TN_LS_1            "%s.blk.%d.ls1.%s" // layer scale
+#define TN_LS_2            "%s.blk.%d.ls2.%s" // layer scale
 #define TN_LN_PRE          "%s.pre_ln.%s"
 #define TN_LN_POST         "%s.post_ln.%s"
-#define TN_TEXT_PROJ       "text_projection.weight"
-#define TN_VIS_PROJ        "visual_projection.weight"
 #define TN_LLAVA_PROJ      "mm.%d.%s"
 #define TN_MVLM_PROJ_MLP   "mm.model.mlp.%d.%s"
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE   "model.image_newline"
+#define TN_MM_INP_NORM     "mm.input_norm.weight"
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
+#define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
+#define TN_MM_PATCH_MERGER "mm.patch_merger.weight"     // mistral small 3.1
+#define TN_TOK_IMG_BREAK   "v.token_embd.img_break"     // pixtral
+#define TN_TOK_GLM_BOI     "adapter.boi"                // glm-edge (these embeddings are not in text model)
+#define TN_TOK_GLM_EOI     "adapter.eoi"                // glm-edge (these embeddings are not in text model)
 
 // mimicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -90,18 +103,34 @@
 #define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s"
 #define TN_GLM_ADAPTER_GATE     "adapter.linear.gate.%s"
 #define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
-#define TN_GLM_BOI_W            "adapter.boi"
-#define TN_GLM_EOI_W            "adapter.eoi"
+
+// ultravox
+#define TN_CONV1D       "a.conv1d.%d.%s"
+#define TN_MM_AUDIO_MLP "mm.a.mlp.%d.%s"
+#define TN_MM_AUDIO_FC  "mm.a.fc.%s" // fully connected layer
+#define TN_MM_NORM_PRE  "mm.a.norm_pre.%s"
+#define TN_MM_NORM_MID  "mm.a.norm_mid.%s"
+
+// align x to upper multiple of n
+#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
 
 enum projector_type {
     PROJECTOR_TYPE_MLP,
     PROJECTOR_TYPE_MLP_NORM,
     PROJECTOR_TYPE_LDP,
     PROJECTOR_TYPE_LDPV2,
-    PROJECTOR_TYPE_RESAMPLER,
+    PROJECTOR_TYPE_MINICPMV,
     PROJECTOR_TYPE_GLM_EDGE,
-    PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_QWEN2VL,
     PROJECTOR_TYPE_GEMMA3,
+    PROJECTOR_TYPE_IDEFICS3,
+    PROJECTOR_TYPE_PIXTRAL,
+    PROJECTOR_TYPE_QWEN25VL,
+    PROJECTOR_TYPE_ULTRAVOX,
+    PROJECTOR_TYPE_INTERNVL,
+    PROJECTOR_TYPE_LLAMA4,
+    PROJECTOR_TYPE_QWEN2A,
+    PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -109,10 +138,18 @@ static std::map PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MLP,       "mlp" },
     { PROJECTOR_TYPE_LDP,       "ldp" },
     { PROJECTOR_TYPE_LDPV2,     "ldpv2"},
-    { PROJECTOR_TYPE_RESAMPLER, "resampler"},
+    { PROJECTOR_TYPE_MINICPMV,  "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
-    { PROJECTOR_TYPE_MERGER,    "qwen2vl_merger"},
+    { PROJECTOR_TYPE_QWEN2VL,   "qwen2vl_merger"},
+    { PROJECTOR_TYPE_QWEN25VL,  "qwen2.5vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
+    { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
+    { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
+    { PROJECTOR_TYPE_ULTRAVOX,  "ultravox"},
+    { PROJECTOR_TYPE_INTERNVL,  "internvl"},
+    { PROJECTOR_TYPE_LLAMA4,    "llama4"},
+    { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
+    { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -132,8 +169,10 @@ struct clip_image_u8 {
     std::vector buf;
 };
 
-// RGB float32 image (NHWC)
-// Memory layout: RGBRGBRGB...
+// For images, buf.size() == nx*ny*3
+//     Memory layout: RGBRGBRGB...
+// For audio, only one channel is used, buf.size() == nx*ny
+//     nx will be n_frames and ny will be n_mel
 struct clip_image_f32 {
     int nx;
     int ny;
@@ -227,6 +266,26 @@ struct clip_image_u8_batch {
 
 struct clip_image_f32_batch {
     std::vector entries;
+    bool is_audio = false;
+
+    // for llava-uhd style models, we need to know the grid size
+    // note: entries.size() == grid_x * grid_y + 1 (one overview image)
+    int grid_x = 0;
+    int grid_y = 0;
+
+    clip_image_f32_batch clone() const {
+        clip_image_f32_batch new_batch{
+            /* entries  */ {},
+            /* is_audio */ is_audio,
+            /* grid_x   */ grid_x,
+            /* grid_y   */ grid_y,
+        };
+        new_batch.entries.reserve(entries.size());
+        for (const auto & entry : entries) {
+            new_batch.entries.emplace_back(new clip_image_f32(*entry));
+        }
+        return new_batch;
+    }
 };
 
 //
@@ -337,6 +396,70 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
     }
 }
 
+//
+// debugging
+//
+
+static void print_tensor_shape(ggml_tensor * t) {
+    printf("%s.shape = [", t->name);
+    for (int i = 0; i < ggml_n_dims(t); ++i) {
+        printf("%" PRId64, t->ne[i]);
+        if (i < ggml_n_dims(t) - 1) {
+            printf(", ");
+        }
+    }
+    printf("]\n");
+}
+
+static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
+    ggml_type type = t->type;
+    int64_t * ne = t->ne;
+    size_t * nb = t->nb;
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        printf("%s.data: [\n", t->name);
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            if (i2 == n && ne[2] > 2*n) {
+                printf("     ..., \n");
+                i2 = ne[2] - n;
+            }
+            printf("     [\n");
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                if (i1 == n && ne[1] > 2*n) {
+                    printf("      ..., \n");
+                    i1 = ne[1] - n;
+                }
+                printf("      [");
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    if (i0 == n && ne[0] > 2*n) {
+                        printf("..., ");
+                        i0 = ne[0] - n;
+                    }
+                    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+                    float v;
+                    if (type == GGML_TYPE_F16) {
+                        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+                    } else if (type == GGML_TYPE_F32) {
+                        v = *(float *) &data[i];
+                    } else if (type == GGML_TYPE_I32) {
+                        v = (float) *(int32_t *) &data[i];
+                    } else if (type == GGML_TYPE_I16) {
+                        v = (float) *(int16_t *) &data[i];
+                    } else if (type == GGML_TYPE_I8) {
+                        v = (float) *(int8_t *) &data[i];
+                    } else {
+                        GGML_ABORT("fatal error");
+                    }
+                    printf("%8.4f", v);
+                    if (i0 < ne[0] - 1) printf(", ");
+                }
+                printf("],\n");
+            }
+            printf("     ],\n");
+        }
+        printf("    ]\n");
+    }
+}
+
 //
 // API used internally with mtmd
 //
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
new file mode 100644
index 000000000..30283d6f1
--- /dev/null
+++ b/tools/mtmd/clip.cpp
@@ -0,0 +1,4157 @@
+// NOTE: This is modified from clip.cpp only for LLaVA,
+// so there might be still unnecessary artifacts hanging around
+// I'll gradually clean and extend it
+// Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
+#include "clip.h"
+#include "clip-impl.h"
+#include "ggml.h"
+#include "ggml-cpp.h"
+#include "ggml-cpu.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "gguf.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
+
+enum ffn_op_type {
+    FFN_GELU,
+    FFN_GELU_ERF,
+    FFN_SILU,
+    FFN_GELU_QUICK,
+};
+
+enum norm_type {
+    NORM_TYPE_NORMAL,
+    NORM_TYPE_RMS,
+};
+
+//#define CLIP_DEBUG_FUNCTIONS
+
+#ifdef CLIP_DEBUG_FUNCTIONS
+static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    // PPM header: P6 format, width, height, and max color value
+    file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
+
+    // Write pixel data
+    for (size_t i = 0; i < img.buf.size(); i += 3) {
+        // PPM expects binary data in RGB format, which matches our image buffer
+        file.write(reinterpret_cast(&img.buf[i]), 3);
+    }
+
+    file.close();
+}
+
+static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) {
+    std::ofstream file(filename, std::ios::binary);
+    if (!file.is_open()) {
+        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
+        return;
+    }
+
+    int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
+    int bytesPerPixel = 3;
+    int widthInBytes = img.nx * bytesPerPixel;
+    int paddingAmount = (4 - (widthInBytes % 4)) % 4;
+    int stride = widthInBytes + paddingAmount;
+
+    // Bitmap file header
+    unsigned char fileHeader[14] = {
+        'B','M',     // Signature
+        0,0,0,0,    // Image file size in bytes
+        0,0,0,0,    // Reserved
+        54,0,0,0    // Start of pixel array
+    };
+
+    // Total file size
+    fileSize = 54 + (stride * img.ny);
+    fileHeader[2] = (unsigned char)(fileSize);
+    fileHeader[3] = (unsigned char)(fileSize >> 8);
+    fileHeader[4] = (unsigned char)(fileSize >> 16);
+    fileHeader[5] = (unsigned char)(fileSize >> 24);
+
+    // Bitmap information header (BITMAPINFOHEADER)
+    unsigned char infoHeader[40] = {
+        40,0,0,0,   // Size of this header (40 bytes)
+        0,0,0,0,    // Image width
+        0,0,0,0,    // Image height
+        1,0,        // Number of color planes
+        24,0,       // Bits per pixel
+        0,0,0,0,    // No compression
+        0,0,0,0,    // Image size (can be 0 for no compression)
+        0,0,0,0,    // X pixels per meter (not specified)
+        0,0,0,0,    // Y pixels per meter (not specified)
+        0,0,0,0,    // Total colors (color table not used)
+        0,0,0,0     // Important colors (all are important)
+    };
+
+    // Width and height in the information header
+    infoHeader[4] = (unsigned char)(img.nx);
+    infoHeader[5] = (unsigned char)(img.nx >> 8);
+    infoHeader[6] = (unsigned char)(img.nx >> 16);
+    infoHeader[7] = (unsigned char)(img.nx >> 24);
+    infoHeader[8] = (unsigned char)(img.ny);
+    infoHeader[9] = (unsigned char)(img.ny >> 8);
+    infoHeader[10] = (unsigned char)(img.ny >> 16);
+    infoHeader[11] = (unsigned char)(img.ny >> 24);
+
+    // Write file headers
+    file.write(reinterpret_cast(fileHeader), sizeof(fileHeader));
+    file.write(reinterpret_cast(infoHeader), sizeof(infoHeader));
+
+    // Pixel data
+    std::vector padding(3, 0); // Max padding size to be added to each row
+    for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
+        for (int x = 0; x < img.nx; ++x) {
+            // Each pixel
+            size_t pixelIndex = (y * img.nx + x) * 3;
+            unsigned char pixel[3] = {
+                img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
+                img.buf[pixelIndex + 1],
+                img.buf[pixelIndex]
+            };
+            file.write(reinterpret_cast(pixel), 3);
+        }
+        // Write padding for the row
+        file.write(reinterpret_cast(padding.data()), paddingAmount);
+    }
+
+    file.close();
+}
+
+// debug function to convert f32 to u8
+static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(3 * src.nx * src.ny);
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        dst.buf[i] = static_cast(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
+    }
+}
+#endif
+
+
+//
+// clip layers
+//
+
+enum patch_merge_type {
+    PATCH_MERGE_FLAT,
+    PATCH_MERGE_SPATIAL_UNPAD,
+};
+
+struct clip_hparams {
+    int32_t image_size;
+    int32_t patch_size;
+    int32_t n_embd;
+    int32_t n_ff;
+    int32_t projection_dim;
+    int32_t n_head;
+    int32_t n_layer;
+    int32_t proj_scale_factor = 0; // idefics3
+
+    float image_mean[3];
+    float image_std[3];
+
+    // for models using dynamic image size, we need to have a smaller image size to warmup
+    // otherwise, user will get OOM everytime they load the model
+    int32_t warmup_image_size = 0;
+    int32_t warmup_audio_size = 3000;
+
+    ffn_op_type ffn_op = FFN_GELU;
+
+    patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
+
+    float eps = 1e-6;
+    float rope_theta = 0.0;
+
+    std::vector image_res_candidates; // for llava-uhd style models
+    int32_t image_crop_resolution;
+    std::unordered_set vision_feature_layer;
+    int32_t attn_window_size = 0;
+    int32_t n_wa_pattern = 0;
+    int32_t spatial_merge_size = 0;
+
+    // audio
+    int32_t n_mel_bins = 0; // whisper preprocessor
+    int32_t proj_stack_factor = 0; // ultravox
+
+    // legacy
+    bool has_llava_projector = false;
+    int minicpmv_version = 0;
+};
+
+struct clip_layer {
+    // attention
+    ggml_tensor * k_w = nullptr;
+    ggml_tensor * k_b = nullptr;
+    ggml_tensor * q_w = nullptr;
+    ggml_tensor * q_b = nullptr;
+    ggml_tensor * v_w = nullptr;
+    ggml_tensor * v_b = nullptr;
+
+    ggml_tensor * o_w = nullptr;
+    ggml_tensor * o_b = nullptr;
+
+    ggml_tensor * k_norm = nullptr;
+    ggml_tensor * q_norm = nullptr;
+
+    // layernorm 1
+    ggml_tensor * ln_1_w = nullptr;
+    ggml_tensor * ln_1_b = nullptr;
+
+    ggml_tensor * ff_up_w = nullptr;
+    ggml_tensor * ff_up_b = nullptr;
+    ggml_tensor * ff_gate_w = nullptr;
+    ggml_tensor * ff_gate_b = nullptr;
+    ggml_tensor * ff_down_w = nullptr;
+    ggml_tensor * ff_down_b = nullptr;
+
+    // layernorm 2
+    ggml_tensor * ln_2_w = nullptr;
+    ggml_tensor * ln_2_b = nullptr;
+
+    // layer scale (no bias)
+    ggml_tensor * ls_1_w = nullptr;
+    ggml_tensor * ls_2_w = nullptr;
+};
+
+struct clip_model {
+    clip_modality modality = CLIP_MODALITY_VISION;
+    projector_type proj_type = PROJECTOR_TYPE_MLP;
+    clip_hparams hparams;
+
+    // embeddings
+    ggml_tensor * class_embedding = nullptr;
+    ggml_tensor * patch_embeddings_0 = nullptr;
+    ggml_tensor * patch_embeddings_1 = nullptr;  // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
+    ggml_tensor * patch_bias = nullptr;
+    ggml_tensor * position_embeddings = nullptr;
+
+    ggml_tensor * pre_ln_w = nullptr;
+    ggml_tensor * pre_ln_b = nullptr;
+
+    std::vector layers;
+
+    ggml_tensor * post_ln_w;
+    ggml_tensor * post_ln_b;
+
+    ggml_tensor * projection; // TODO: rename it to fc (fully connected layer)
+    ggml_tensor * mm_fc_w;
+    ggml_tensor * mm_fc_b;
+
+    // LLaVA projection
+    ggml_tensor * mm_input_norm_w = nullptr;
+    ggml_tensor * mm_0_w = nullptr;
+    ggml_tensor * mm_0_b = nullptr;
+    ggml_tensor * mm_2_w = nullptr;
+    ggml_tensor * mm_2_b = nullptr;
+
+    ggml_tensor * image_newline = nullptr;
+
+    // Yi type models with mlp+normalization projection
+    ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
+    ggml_tensor * mm_1_b = nullptr;
+    ggml_tensor * mm_3_w = nullptr;
+    ggml_tensor * mm_3_b = nullptr;
+    ggml_tensor * mm_4_w = nullptr;
+    ggml_tensor * mm_4_b = nullptr;
+
+    // GLMV-Edge projection
+    ggml_tensor * mm_model_adapter_conv_w = nullptr;
+    ggml_tensor * mm_model_adapter_conv_b = nullptr;
+    ggml_tensor * mm_glm_tok_boi = nullptr;
+    ggml_tensor * mm_glm_tok_eoi = nullptr;
+
+    // MobileVLM projection
+    ggml_tensor * mm_model_mlp_1_w = nullptr;
+    ggml_tensor * mm_model_mlp_1_b = nullptr;
+    ggml_tensor * mm_model_mlp_3_w = nullptr;
+    ggml_tensor * mm_model_mlp_3_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;
+
+    // MobileVLM_V2 projection
+    ggml_tensor * mm_model_mlp_0_w = nullptr;
+    ggml_tensor * mm_model_mlp_0_b = nullptr;
+    ggml_tensor * mm_model_mlp_2_w = nullptr;
+    ggml_tensor * mm_model_mlp_2_b = nullptr;
+    ggml_tensor * mm_model_peg_0_w = nullptr;
+    ggml_tensor * mm_model_peg_0_b = nullptr;
+
+    // MINICPMV projection
+    ggml_tensor * mm_model_pos_embed_k = nullptr;
+    ggml_tensor * mm_model_query = nullptr;
+    ggml_tensor * mm_model_proj = nullptr;
+    ggml_tensor * mm_model_kv_proj = nullptr;
+    ggml_tensor * mm_model_attn_q_w = nullptr;
+    ggml_tensor * mm_model_attn_q_b = nullptr;
+    ggml_tensor * mm_model_attn_k_w = nullptr;
+    ggml_tensor * mm_model_attn_k_b = nullptr;
+    ggml_tensor * mm_model_attn_v_w = nullptr;
+    ggml_tensor * mm_model_attn_v_b = nullptr;
+    ggml_tensor * mm_model_attn_o_w = nullptr;
+    ggml_tensor * mm_model_attn_o_b = nullptr;
+    ggml_tensor * mm_model_ln_q_w = nullptr;
+    ggml_tensor * mm_model_ln_q_b = nullptr;
+    ggml_tensor * mm_model_ln_kv_w = nullptr;
+    ggml_tensor * mm_model_ln_kv_b = nullptr;
+    ggml_tensor * mm_model_ln_post_w = nullptr;
+    ggml_tensor * mm_model_ln_post_b = nullptr;
+
+    // gemma3
+    ggml_tensor * mm_input_proj_w = nullptr;
+    ggml_tensor * mm_soft_emb_norm_w = nullptr;
+
+    // pixtral
+    ggml_tensor * token_embd_img_break = nullptr;
+    ggml_tensor * mm_patch_merger_w = nullptr;
+
+    // ultravox / whisper encoder
+    ggml_tensor * conv1d_1_w = nullptr;
+    ggml_tensor * conv1d_1_b = nullptr;
+    ggml_tensor * conv1d_2_w = nullptr;
+    ggml_tensor * conv1d_2_b = nullptr;
+    ggml_tensor * mm_norm_pre_w = nullptr;
+    ggml_tensor * mm_norm_mid_w = nullptr;
+};
+
+struct clip_ctx {
+    clip_model model;
+
+    gguf_context_ptr ctx_gguf;
+    ggml_context_ptr ctx_data;
+
+    std::vector buf_compute_meta;
+
+    std::vector backend_ptrs;
+    std::vector backend_buft;
+
+    ggml_backend_t backend;
+    ggml_backend_t backend_cpu;
+    ggml_backend_buffer_ptr buf;
+
+    int max_nodes = 8192;
+    ggml_backend_sched_ptr sched;
+
+    // for debugging
+    bool debug_graph = false;
+    std::vector debug_print_tensors;
+
+    clip_ctx(clip_context_params & ctx_params) {
+        debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
+        backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        if (!backend_cpu) {
+            throw std::runtime_error("failed to initialize CPU backend");
+        }
+        backend = ctx_params.use_gpu
+                    ? ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr)
+                    : nullptr;
+
+        if (backend) {
+            LOG_INF("%s: CLIP using %s backend\n", __func__, ggml_backend_name(backend));
+            backend_ptrs.push_back(backend);
+            backend_buft.push_back(ggml_backend_get_default_buffer_type(backend));
+        } else {
+            backend = backend_cpu;
+            LOG_INF("%s: CLIP using CPU backend\n", __func__);
+        }
+
+        backend_ptrs.push_back(backend_cpu);
+        backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
+
+        sched.reset(
+            ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), 8192, false, true)
+        );
+    }
+
+    ~clip_ctx() {
+        ggml_backend_free(backend);
+        if (backend != backend_cpu) {
+            ggml_backend_free(backend_cpu);
+        }
+    }
+
+    // this function is added so that we don't change too much of the existing code
+    projector_type proj_type() const {
+        return model.proj_type;
+    }
+};
+
+struct clip_graph {
+    clip_ctx * ctx;
+    const clip_model & model;
+    const clip_hparams & hparams;
+
+    // we only support single image per batch
+    const clip_image_f32 & img;
+
+    const int patch_size;
+    const int n_patches_x;
+    const int n_patches_y;
+    const int n_patches;
+    const int n_embd;
+    const int n_head;
+    const int d_head;
+    const int n_layer;
+    const float eps;
+    const float kq_scale;
+
+    ggml_context_ptr ctx0_ptr;
+    ggml_context * ctx0;
+    ggml_cgraph * gf;
+
+    clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
+            ctx(ctx),
+            model(ctx->model),
+            hparams(model.hparams),
+            img(img),
+            patch_size(hparams.patch_size),
+            n_patches_x(img.nx / patch_size),
+            n_patches_y(img.ny / patch_size),
+            n_patches(n_patches_x * n_patches_y),
+            n_embd(hparams.n_embd),
+            n_head(hparams.n_head),
+            d_head(n_embd / n_head),
+            n_layer(hparams.n_layer),
+            eps(hparams.eps),
+            kq_scale(1.0f / sqrtf((float)d_head)) {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+            /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+        ctx0_ptr.reset(ggml_init(params));
+        ctx0 = ctx0_ptr.get();
+        gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
+    }
+
+    ggml_cgraph * build_siglip() {
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        if (ctx->proj_type() == PROJECTOR_TYPE_GEMMA3) {
+            const int batch_size = 1;
+            GGML_ASSERT(n_patches_x == n_patches_y);
+            const int patches_per_image = n_patches_x;
+            const int kernel_size = hparams.proj_scale_factor;
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+            cur = ggml_reshape_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);
+
+            // doing a pool2d to reduce the number of output tokens
+            cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size);
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            // apply norm before projection
+            cur = ggml_rms_norm(ctx0, cur, eps);
+            cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
+
+            // apply projection
+            cur = ggml_mul_mat(ctx0,
+                ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
+                cur);
+
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
+            // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
+
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int n_embd = cur->ne[0];
+            const int seq    = cur->ne[1];
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            const int height = std::sqrt(seq);
+            const int width  = std::sqrt(seq);
+            GGML_ASSERT(scale_factor != 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                seq / (scale_factor * scale_factor),
+                bsz);
+
+            cur = ggml_mul_mat(ctx0, model.projection, cur);
+        } else {
+            GGML_ABORT("SigLIP: Unsupported projector type");
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_pixtral() {
+        const int n_merge = hparams.spatial_merge_size;
+
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
+        };
+
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_RMS,
+                                hparams.ffn_op,
+                                nullptr, // no learned pos embd
+                                add_pos);
+
+        // mistral small 3.1 patch merger
+        // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+        if (model.mm_patch_merger_w) {
+            GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+            cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+
+            // reshape image tokens to 2D grid
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
+            cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
+            cur = ggml_cont(ctx0, cur);
+
+            // torch.nn.functional.unfold is just an im2col under the hood
+            // we just need a dummy kernel to make it work
+            ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+            cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
+
+            // project to n_embd
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+            cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        }
+
+        // LlavaMultiModalProjector (always using GELU activation)
+        {
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            if (model.mm_1_b) {
+                cur = ggml_add(ctx0, cur, model.mm_1_b);
+            }
+
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            if (model.mm_2_b) {
+                cur = ggml_add(ctx0, cur, model.mm_2_b);
+            }
+        }
+
+        // arrangement of the [IMG_BREAK] token
+        {
+            // not efficient, but works
+            // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
+            // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
+            // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
+
+            const int p_y             = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+            const int p_x             = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+            const int p_total         = p_x * p_y;
+            const int n_embd_text     = cur->ne[0];
+            const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
+
+            ggml_tensor * tmp = ggml_reshape_3d(ctx0, cur, n_embd_text, p_x, p_y);
+            ggml_tensor * tok = ggml_new_tensor_3d(ctx0, tmp->type, n_embd_text, 1, p_y);
+            tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
+            tok = ggml_add(ctx0, tok, model.token_embd_img_break);
+            tmp = ggml_concat(ctx0, tmp, tok, 1);
+            cur = ggml_view_2d(ctx0, tmp,
+                n_embd_text, n_tokens_output,
+                ggml_row_size(tmp->type, n_embd_text), 0);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // Qwen2VL and Qwen2.5VL use M-RoPE
+    ggml_cgraph * build_qwen2vl() {
+        GGML_ASSERT(model.patch_bias == nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
+        const int batch_size       = 1;
+        const bool use_window_attn = hparams.n_wa_pattern > 0;
+        const int n_wa_pattern     = hparams.n_wa_pattern;
+        const int n_pos            = n_patches;
+        const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
+
+        norm_type norm_t = ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL
+            ? NORM_TYPE_RMS // qwen 2.5 vl
+            : NORM_TYPE_NORMAL; // qwen 2 vl
+
+        int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+        GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+        GGML_ASSERT(img.ny % (patch_size * 2) == 0);
+
+        // second conv dimension
+        {
+            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_add(ctx0, inp, inp_1);
+
+            inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+            inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
+            inp = ggml_reshape_3d(
+                ctx0, inp,
+                n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        ggml_tensor * inpL           = inp;
+        ggml_tensor * window_mask    = nullptr;
+        ggml_tensor * window_idx     = nullptr;
+        ggml_tensor * inv_window_idx = nullptr;
+
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+        }
+
+        if (use_window_attn) {
+            // handle window attention inputs
+            inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+            ggml_set_name(inv_window_idx, "inv_window_idx");
+            ggml_set_input(inv_window_idx);
+            // mask for window attention
+            window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
+            ggml_set_name(window_mask, "window_mask");
+            ggml_set_input(window_mask);
+
+            // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+            GGML_ASSERT(batch_size == 1);
+            inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
+            inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+            const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
+
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "ln1", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+                ggml_tensor * Kcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+                ggml_tensor * Vcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // apply M-RoPE
+                Qcur = ggml_rope_multi(
+                    ctx0, Qcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+                Kcur = ggml_rope_multi(
+                    ctx0, Kcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+
+                cb(Qcur, "Qcur_rope", il);
+                cb(Kcur, "Kcur_rope", il);
+
+                ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+        }
+
+        // multimodal projection
+        ggml_tensor * embeddings = inpL;
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+
+        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+        // GELU activation
+        embeddings = ggml_gelu(ctx0, embeddings);
+
+        // Second linear layer
+        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+
+        if (use_window_attn) {
+            window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+            ggml_set_name(window_idx, "window_idx");
+            ggml_set_input(window_idx);
+
+            // embeddings shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+            GGML_ASSERT(batch_size == 1);
+            embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4);
+            embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+            embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4, batch_size);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_minicpmv() {
+        const int batch_size = 1;
+
+        GGML_ASSERT(model.class_embedding == nullptr);
+        const int n_pos = n_patches;
+
+        // position embeddings for the projector (not for ViT)
+        int n_output_dim = clip_n_mmproj_embd(ctx);
+        ggml_tensor * pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, n_pos, batch_size);
+        ggml_set_name(pos_embed, "pos_embed");
+        ggml_set_input(pos_embed);
+
+        // for selecting learned pos embd, used by ViT
+        struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);
+
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * embeddings = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                learned_pos_embd,
+                                nullptr);
+
+        // resampler projector (it is just another transformer)
+
+        ggml_tensor * q = model.mm_model_query;
+        ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+
+        // norm
+        q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
+        v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1);
+
+        // k = v + pos_embed
+        ggml_tensor * k = ggml_add(ctx0, v, pos_embed);
+
+        // attention
+        {
+            int n_embd = clip_n_mmproj_embd(ctx);
+            const int d_head = 128;
+            int n_head = n_embd/d_head;
+            int num_query = 96;
+            if (ctx->model.hparams.minicpmv_version == 2) {
+                num_query = 96;
+            } else if (ctx->model.hparams.minicpmv_version == 3) {
+                num_query = 64;
+            } else if (ctx->model.hparams.minicpmv_version == 4) {
+                num_query = 64;
+            }
+
+            ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
+                model.mm_model_attn_q_b);
+            ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
+                model.mm_model_attn_k_b);
+            ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
+                model.mm_model_attn_v_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos);
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos);
+
+            cb(Q, "resampler_Q", -1);
+            cb(K, "resampler_K", -1);
+            cb(V, "resampler_V", -1);
+
+            embeddings = build_attn(
+                model.mm_model_attn_o_w,
+                model.mm_model_attn_o_b,
+                Q, K, V, nullptr, kq_scale, -1);
+            cb(embeddings, "resampler_attn_out", -1);
+        }
+        // layernorm
+        embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
+
+        // projection
+        embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_internvl() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
+        ggml_tensor * inp = build_inp();
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        // The larger models use a different ViT, which uses RMS norm instead of layer norm
+        // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188
+        norm_type norm_t = (hparams.n_embd == 3200 && hparams.n_layer == 45)
+            ? NORM_TYPE_RMS // 6B ViT (Used by InternVL 2.5/3 - 26B, 38B, 78B)
+            : NORM_TYPE_NORMAL; // 300M ViT (Used by all smaller InternVL models)
+
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                norm_t,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width  = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                cur->ne[1] * cur->ne[2]);
+        }
+
+        // projector (always using GELU activation)
+        {
+            // projector LayerNorm uses pytorch's default eps = 1e-5
+            // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
+            cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_3_b);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    ggml_cgraph * build_llama4() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1; // +1 for [CLS]
+
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * inp = build_inp_raw();
+
+        // Llama4UnfoldConvolution
+        {
+            ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
+                                                    patch_size, patch_size, 3, n_embd);
+            inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
+            inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+            inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+            cb(inp, "patch_conv", -1);
+        }
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            // first half is X axis and second half is Y axis
+            // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312
+            // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                add_pos);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        // based on Llama4VisionPixelShuffleMLP
+        // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            GGML_ASSERT(scale_factor > 0);
+            GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
+            cur = ggml_reshape_4d(ctx0, cur,
+                n_embd * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                n_patches / scale_factor / scale_factor);
+            cb(cur, "pixel_shuffle", -1);
+        }
+
+        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+        {
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cb(cur, "adapter_mlp", -1);
+        }
+
+        // Llama4MultiModalProjector
+        cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+        cb(cur, "projected", -1);
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+    // this graph is used by llava, granite and glm
+    // due to having embedding_stack (used by granite), we cannot reuse build_vit
+    ggml_cgraph * build_llava() {
+        const int batch_size = 1;
+        const int n_pos = n_patches + (model.class_embedding ? 1 : 0);
+
+        GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported");
+
+        // Calculate the deepest feature layer based on hparams and projector type
+        int max_feature_layer = n_layer;
+        {
+            // Get the index of the second to last layer; this is the default for models that have a llava projector
+            int il_last = hparams.n_layer - 1;
+            int deepest_feature_layer = -1;
+
+            if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV || ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
+                il_last += 1;
+            }
+
+            // If we set explicit vision feature layers, only go up to the deepest one
+            // NOTE: only used by granite-vision models for now
+            for (const auto & feature_layer : hparams.vision_feature_layer) {
+                if (feature_layer > deepest_feature_layer) {
+                    deepest_feature_layer = feature_layer;
+                }
+            }
+            max_feature_layer = deepest_feature_layer < 0 ? il_last : deepest_feature_layer;
+        }
+
+        ggml_tensor * inp = build_inp();
+
+        // concat class_embeddings and patch_embeddings
+        if (model.class_embedding) {
+            inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+        }
+
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));
+
+        ggml_tensor * inpL = inp;
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1);
+            cb(inpL, "pre_ln", -1);
+        }
+
+        std::vector embedding_stack;
+        const auto & vision_feature_layer = hparams.vision_feature_layer;
+
+        // loop over layers
+        for (int il = 0; il < max_feature_layer; il++) {
+            auto & layer = model.layers[il];
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // If this is an embedding feature layer, save the output.
+            // NOTE: 0 index here refers to the input to the encoder.
+            if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
+                embedding_stack.push_back(cur);
+            }
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "layer_inp_normed", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                if (layer.q_b) {
+                    Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+                }
+
+                ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                if (layer.k_b) {
+                    Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+                }
+
+                ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                if (layer.v_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
+
+            cb(cur, "ffn_out", il);
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
+        }
+
+        ggml_tensor * embeddings = inpL;
+
+        // process vision feature layers (used by granite)
+        {
+            // final layer is a vision feature layer
+            if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) {
+                embedding_stack.push_back(inpL);
+            }
+
+            // If feature layers are explicitly set, stack them (if we have multiple)
+            if (!embedding_stack.empty()) {
+                embeddings = embedding_stack[0];
+                for (size_t i = 1; i < embedding_stack.size(); i++) {
+                    embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
+                }
+            }
+        }
+
+        // llava projector (also used by granite)
+        if (ctx->model.hparams.has_llava_projector) {
+            embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
+
+            ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+            ggml_set_name(patches, "patches");
+            ggml_set_input(patches);
+
+            // shape [1, 576, 1024]
+            // ne is whcn, ne = [1024, 576, 1, 1]
+            embeddings = ggml_get_rows(ctx0, embeddings, patches);
+
+            // print_tensor_info(embeddings, "embeddings");
+
+            // llava projector
+            if (ctx->proj_type() == PROJECTOR_TYPE_MLP) {
+                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+
+                embeddings = ggml_gelu(ctx0, embeddings);
+                if (model.mm_2_w) {
+                    embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+                    embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+                }
+            }
+            else if (ctx->proj_type() == PROJECTOR_TYPE_MLP_NORM) {
+                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+                // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
+                // First LayerNorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
+                                    model.mm_1_b);
+
+                // GELU activation
+                embeddings = ggml_gelu(ctx0, embeddings);
+
+                // Second linear layer
+                embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
+
+                // Second LayerNorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
+                                    model.mm_4_b);
+            }
+            else if (ctx->proj_type() == PROJECTOR_TYPE_LDP) {
+                // MobileVLM projector
+                int n_patch = 24;
+                ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
+                mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
+                mlp_1 = ggml_gelu(ctx0, mlp_1);
+                ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
+                mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
+                // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
+
+                // block 1
+                ggml_tensor * block_1 = nullptr;
+                {
+                    // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
+                    mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
+                    mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
+                    // stride = 1, padding = 1, bias is nullptr
+                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+
+                    // layer norm
+                    // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+
+                    // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    // hardswish
+                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    // pointwise conv
+                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
+                    block_1 = ggml_relu(ctx0, block_1);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
+                    block_1 = ggml_hardsigmoid(ctx0, block_1);
+                    // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
+                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                    int w = block_1->ne[0], h = block_1->ne[1];
+                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+
+                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
+                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                    // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    // residual
+                    block_1 = ggml_add(ctx0, mlp_3, block_1);
+                }
+
+                // block_2
+                {
+                    // stride = 2
+                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+
+                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                    // layer norm
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                    // hardswish
+                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                    // not sure the parameters is right for globalAvgPooling
+                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    // pointwise conv
+                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
+                    block_1 = ggml_relu(ctx0, block_1);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
+                    block_1 = ggml_hardsigmoid(ctx0, block_1);
+
+                    // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                    int w = block_1->ne[0], h = block_1->ne[1];
+                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
+                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+
+                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
+                    block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
+                    // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
+                }
+                embeddings = block_1;
+            }
+            else if (ctx->proj_type() == PROJECTOR_TYPE_LDPV2)
+            {
+                int n_patch = 24;
+                ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+                mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
+                mlp_0 = ggml_gelu(ctx0, mlp_0);
+                ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
+                mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
+                // mlp_2 ne = [2048, 576, 1, 1]
+                // // AVG Pool Layer 2*2, strides = 2
+                mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
+                // mlp_2 ne = [576, 2048, 1, 1]
+                mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
+                // mlp_2 ne [24, 24, 2048, 1]
+                mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
+                // weight ne = [3, 3, 2048, 1]
+                ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+                peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
+                peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
+                mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
+                peg_0 = ggml_add(ctx0, peg_0, mlp_2);
+                peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
+                embeddings = peg_0;
+            }
+            else {
+                GGML_ABORT("fatal error");
+            }
+        }
+
+        // glm projector
+        else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE) {
+            size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
+            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
+            embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
+            embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
+            embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
+            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
+            embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
+            // GLU
+            {
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+                embeddings = ggml_gelu_inplace(ctx0, embeddings);
+                ggml_tensor * x = embeddings;
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
+                x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
+                embeddings = ggml_silu_inplace(ctx0, embeddings);
+                embeddings = ggml_mul(ctx0, embeddings,x);
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
+            }
+            // arrangement of BOI/EOI token embeddings
+            // note: these embeddings are not present in text model, hence we cannot process them as text tokens
+            // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
+            {
+                embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI
+                embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI
+            }
+        }
+
+        else {
+            GGML_ABORT("llava: unknown projector type");
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
+
+        return gf;
+    }
+
+    // whisper encoder with custom projector
+    ggml_cgraph * build_whisper_enc() {
+        const int n_frames = img.nx;
+        const int n_pos    = n_frames / 2;
+        GGML_ASSERT(model.position_embeddings->ne[1] >= n_pos);
+
+        ggml_tensor * inp = build_inp_raw(1);
+
+        // conv1d block
+        {
+            // convolution + gelu
+            ggml_tensor * cur = ggml_conv_1d_ph(ctx0, model.conv1d_1_w, inp, 1, 1);
+            cur = ggml_add(ctx0, cur, model.conv1d_1_b);
+
+            cur = ggml_gelu_erf(ctx0, cur);
+
+            cur = ggml_conv_1d_ph(ctx0, model.conv1d_2_w, cur, 2, 1);
+            cur = ggml_add(ctx0, cur, model.conv1d_2_b);
+
+            cur = ggml_gelu_erf(ctx0, cur);
+            // transpose
+            inp = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+            cb(inp, "after_conv1d", -1);
+        }
+
+        // sanity check (only check one layer, but it should be the same for all)
+        GGML_ASSERT(model.layers[0].ln_1_w && model.layers[0].ln_1_b);
+        GGML_ASSERT(model.layers[0].ln_2_w && model.layers[0].ln_2_b);
+        GGML_ASSERT(model.layers[0].q_b);
+        GGML_ASSERT(model.layers[0].v_b);
+        GGML_ASSERT(!model.layers[0].k_b); // no bias for k
+        GGML_ASSERT(model.post_ln_w && model.post_ln_b);
+
+        ggml_tensor * pos_embd_selected = ggml_view_2d(
+            ctx0, model.position_embeddings,
+            model.position_embeddings->ne[0], n_pos,
+            model.position_embeddings->nb[1], 0
+        );
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                pos_embd_selected,
+                                nullptr);
+
+        cb(cur, "after_transformer", -1);
+
+        if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
+            // StackAudioFrames
+            // https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
+            {
+                int64_t stride = n_embd * hparams.proj_stack_factor;
+                int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
+                int64_t pad = padded_len - ggml_nelements(cur);
+                if (pad > 0) {
+                    cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
+                    cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
+                }
+                cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
+                                    ggml_row_size(cur->type, stride), 0);
+            }
+
+            cb(cur, "after_stacked", -1);
+
+            // UltravoxProjector
+            {
+                // pre-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
+
+                // ffn in
+                cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+
+                // swiglu
+                {
+                    int64_t split_point = cur->ne[0] / 2;
+                    ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
+                    ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
+
+                    // see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
+                    x1 = ggml_silu(ctx0, x1);
+                    cur = ggml_mul(ctx0, x0, x1);
+                }
+
+                // mid-norm
+                cur = ggml_rms_norm(ctx0, cur, 1e-6);
+                cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
+
+                // ffn out
+                cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            }
+
+        } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
+            // projector
+            cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_fc_b);
+
+        } else {
+            GGML_ABORT("%s: unknown projector type", __func__);
+        }
+
+        cb(cur, "projected", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
+private:
+    //
+    // utility functions
+    //
+
+    void cb(ggml_tensor * cur0, const char * name, int il) const {
+        if (ctx->debug_graph) {
+            ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
+            std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
+            ggml_set_name(cur, cur_name.c_str());
+            ggml_set_output(cur);
+            ggml_build_forward_expand(gf, cur);
+            ctx->debug_print_tensors.push_back(cur);
+        }
+    }
+
+    // build vision transformer (ViT) cgraph
+    // this function should cover most of the models
+    // if your model has specific features, you should probably duplicate this function
+    ggml_tensor * build_vit(
+                ggml_tensor * inp,
+                int64_t n_pos,
+                norm_type norm_t,
+                ffn_op_type ffn_t,
+                ggml_tensor * learned_pos_embd,
+                std::function add_pos
+            ) {
+        if (learned_pos_embd) {
+            inp = ggml_add(ctx0, inp, learned_pos_embd);
+            cb(inp, "pos_embed", -1);
+        }
+
+        ggml_tensor * inpL = inp;
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+            cb(inpL, "pre_ln", -1);
+        }
+
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
+
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "layer_inp_normed", il);
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                if (layer.q_b) {
+                    Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+                }
+
+                ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                if (layer.k_b) {
+                    Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+                }
+
+                ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                if (layer.v_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+                }
+
+                if (layer.q_norm) {
+                    Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il);
+                    cb(Qcur, "Qcur_norm", il);
+                }
+
+                if (layer.k_norm) {
+                    Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il);
+                    cb(Kcur, "Kcur_norm", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                if (add_pos) {
+                    Qcur = add_pos(Qcur, layer);
+                    Kcur = add_pos(Kcur, layer);
+                    cb(Qcur, "Qcur_pos", il);
+                    cb(Kcur, "Kcur_pos", il);
+                }
+
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (layer.ls_1_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+                cb(cur, "attn_out_scaled", il);
+            }
+
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
+
+            inpL = cur; // inpL = residual, cur = hidden_states
+
+            cb(cur, "ffn_inp", il);
+
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
+
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                ffn_t, il);
+
+            cb(cur, "ffn_out", il);
+
+            if (layer.ls_2_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+                cb(cur, "ffn_out_scaled", il);
+            }
+
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
+
+            inpL = cur;
+        }
+
+        // TODO @ngxson : find a way to move this outside
+        if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
+            ggml_tensor * cur = inpL;
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            cur = ggml_pool_1d(ctx0, cur, GGML_OP_POOL_AVG, 2, 2, 0);
+            cur = ggml_transpose(ctx0, cur);
+            cur = ggml_cont(ctx0, cur);
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
+        }
+        return inpL;
+    }
+
+    // build the input after conv2d (inp_raw --> patches)
+    // returns tensor with shape [n_embd, n_patches]
+    ggml_tensor * build_inp() {
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
+        return inp;
+    }
+
+    ggml_tensor * build_inp_raw(int channels = 3) {
+        ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, channels);
+        ggml_set_name(inp_raw, "inp_raw");
+        ggml_set_input(inp_raw);
+        return inp_raw;
+    }
+
+    ggml_tensor * build_norm(
+            ggml_tensor * cur,
+            ggml_tensor * mw,
+            ggml_tensor * mb,
+            norm_type type,
+            float norm_eps,
+            int il) const {
+
+        cur = type == NORM_TYPE_RMS
+            ? ggml_rms_norm(ctx0, cur, norm_eps)
+            : ggml_norm(ctx0, cur, norm_eps);
+
+        if (mw || mb) {
+            cb(cur, "norm", il);
+        }
+
+        if (mw) {
+            cur = ggml_mul(ctx0, cur, mw);
+            if (mb) {
+                cb(cur, "norm_w", il);
+            }
+        }
+
+        if (mb) {
+            cur = ggml_add(ctx0, cur, mb);
+        }
+
+        return cur;
+    }
+
+    ggml_tensor * build_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * up,
+            ggml_tensor * up_b,
+            ggml_tensor * gate,
+            ggml_tensor * gate_b,
+            ggml_tensor * down,
+            ggml_tensor * down_b,
+            ffn_op_type type_op,
+            int il) const {
+
+        ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
+        cb(tmp, "ffn_up", il);
+
+        if (up_b) {
+            tmp = ggml_add(ctx0, tmp, up_b);
+            cb(tmp, "ffn_up_b", il);
+        }
+
+        if (gate) {
+            cur = ggml_mul_mat(ctx0, gate, cur);
+            cb(cur, "ffn_gate", il);
+
+            if (gate_b) {
+                cur = ggml_add(ctx0, cur, gate_b);
+                cb(cur, "ffn_gate_b", il);
+            }
+        } else {
+            cur = tmp;
+        }
+
+        switch (type_op) {
+            case FFN_SILU:
+                {
+                    cur = ggml_silu(ctx0, cur);
+                    cb(cur, "ffn_silu", il);
+                } break;
+            case FFN_GELU:
+                {
+                    cur = ggml_gelu(ctx0, cur);
+                    cb(cur, "ffn_gelu", il);
+                } break;
+            case FFN_GELU_ERF:
+                {
+                    cur = ggml_gelu_erf(ctx0, cur);
+                    cb(cur, "ggml_gelu_erf", il);
+                } break;
+            case FFN_GELU_QUICK:
+                {
+                    cur = ggml_gelu_quick(ctx0, cur);
+                    cb(cur, "ffn_relu", il);
+                } break;
+        }
+
+        // we only support parallel ffn for now
+        if (gate) {
+            cur = ggml_mul(ctx0, cur, tmp);
+            cb(cur, "ffn_gate_par", il);
+        }
+
+        if (down) {
+            cur = ggml_mul_mat(ctx0, down, cur);
+        }
+
+        if (down_b) {
+            cb(cur, "ffn_down", il);
+        }
+
+        if (down_b) {
+            cur = ggml_add(ctx0, cur, down_b);
+        }
+
+        return cur;
+    }
+
+    ggml_tensor * build_attn(
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_mask,
+            float kq_scale,
+            int il) const {
+        // these nodes are added to the graph together so that they are not reordered
+        // by doing so, the number of splits in the graph is reduced
+        ggml_build_forward_expand(gf, q_cur);
+        ggml_build_forward_expand(gf, k_cur);
+        ggml_build_forward_expand(gf, v_cur);
+
+        ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+        //cb(q, "q", il);
+
+        ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+        //cb(k, "k", il);
+
+        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+        v = ggml_cont(ctx0, v);
+        //cb(k, "v", il);
+
+        ggml_tensor * cur;
+
+        // TODO @ngxson : support flash attention
+        {
+            const auto n_tokens = q->ne[1];
+            const auto n_head   = q->ne[2];
+            // const auto n_kv     = k->ne[1]; // for flash attention
+
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // F32 may not needed for vision encoders?
+            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
+
+            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        }
+
+        cb(cur, "kqv_out", il);
+
+        if (wo) {
+            cur = ggml_mul_mat(ctx0, wo, cur);
+        }
+
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
+        }
+
+        return cur;
+    }
+
+    // implementation of the 2D RoPE without adding a new op in ggml
+    // this is not efficient (use double the memory), but works on all backends
+    // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
+    static ggml_tensor * build_rope_2d(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * pos_a, // first half
+        ggml_tensor * pos_b, // second half
+        const float freq_base,
+        const bool interleave_freq
+    ) {
+        const int64_t n_dim  = cur->ne[0];
+        const int64_t n_head = cur->ne[1];
+        const int64_t n_pos  = cur->ne[2];
+
+        // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
+        // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
+        // first half of cur will use 1e-0, 1e-2 (even)
+        // second half of cur will use 1e-1, 1e-3 (odd)
+        // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
+        //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
+        // then for the second half, we use freq_scale to shift the inv_freq
+        //  ^ why? replace (2i) with (2i+1) in the above equation
+        const float freq_scale_odd = interleave_freq
+                                    ? std::pow(freq_base, (float)-2/n_dim)
+                                    : 1.0;
+
+        // first half
+        ggml_tensor * first;
+        {
+            first = ggml_view_3d(ctx0, cur,
+                n_dim/2, n_head, n_pos,
+                ggml_row_size(cur->type, n_dim),
+                ggml_row_size(cur->type, n_dim*n_head),
+                0);
+            first = ggml_rope_ext(
+                ctx0,
+                first,
+                pos_a,      // positions
+                nullptr,    // freq factors
+                n_dim/2,    // n_dims
+                0, 0, freq_base,
+                1.0f, 0.0f, 1.0f, 0.0f, 0.0f
+            );
+        }
+
+        // second half
+        ggml_tensor * second;
+        {
+            second = ggml_view_3d(ctx0, cur,
+                n_dim/2, n_head, n_pos,
+                ggml_row_size(cur->type, n_dim),
+                ggml_row_size(cur->type, n_dim*n_head),
+                n_dim/2 * ggml_element_size(cur));
+            second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
+            second = ggml_rope_ext(
+                ctx0,
+                second,
+                pos_b,      // positions
+                nullptr,    // freq factors
+                n_dim/2,    // n_dims
+                0, 0, freq_base,
+                freq_scale_odd,
+                0.0f, 1.0f, 0.0f, 0.0f
+            );
+        }
+
+        cur = ggml_concat(ctx0, first, second, 0);
+        return cur;
+    }
+
+};
+
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+    GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
+    clip_graph graph(ctx, *imgs.entries[0]);
+
+    ggml_cgraph * res;
+
+    switch (ctx->proj_type()) {
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_IDEFICS3:
+            {
+                res = graph.build_siglip();
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                res = graph.build_pixtral();
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                res = graph.build_qwen2vl();
+            } break;
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                res = graph.build_minicpmv();
+            } break;
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                res = graph.build_internvl();
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                res = graph.build_llama4();
+            } break;
+        case PROJECTOR_TYPE_ULTRAVOX:
+        case PROJECTOR_TYPE_QWEN2A:
+            {
+                res = graph.build_whisper_enc();
+            } break;
+        default:
+            {
+                res = graph.build_llava();
+            } break;
+    }
+    return res;
+}
+
+struct clip_model_loader {
+    ggml_context_ptr ctx_meta;
+    gguf_context_ptr ctx_gguf;
+
+    std::string fname;
+
+    size_t model_size = 0; // in bytes
+
+    bool has_vision = false;
+    bool has_audio  = false;
+
+    // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model
+    clip_model_loader(const char * fname) : fname(fname) {
+        struct ggml_context * meta = nullptr;
+
+        struct gguf_init_params params = {
+            /*.no_alloc = */ true,
+            /*.ctx      = */ &meta,
+        };
+
+        ctx_gguf = gguf_context_ptr(gguf_init_from_file(fname, params));
+        if (!ctx_gguf.get()) {
+            throw std::runtime_error(string_format("%s: failed to load CLIP model from %s. Does this file exist?\n", __func__, fname));
+        }
+
+        ctx_meta.reset(meta);
+
+        const int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
+
+        // print gguf info
+        {
+            std::string name;
+            get_string(KEY_NAME, name, false);
+            std::string description;
+            get_string(KEY_DESCRIPTION, description, false);
+            LOG_INF("%s: model name:   %s\n",  __func__, name.c_str());
+            LOG_INF("%s: description:  %s\n",  __func__, description.c_str());
+            LOG_INF("%s: GGUF version: %d\n",  __func__, gguf_get_version(ctx_gguf.get()));
+            LOG_INF("%s: alignment:    %zu\n", __func__, gguf_get_alignment(ctx_gguf.get()));
+            LOG_INF("%s: n_tensors:    %d\n",  __func__, n_tensors);
+            LOG_INF("%s: n_kv:         %d\n",  __func__, (int)gguf_get_n_kv(ctx_gguf.get()));
+            LOG_INF("\n");
+        }
+
+        // modalities
+        {
+            get_bool(KEY_HAS_VISION_ENC, has_vision, false);
+            get_bool(KEY_HAS_AUDIO_ENC,  has_audio,  false);
+
+            if (has_vision) {
+                LOG_INF("%s: has vision encoder\n", __func__);
+            }
+            if (has_audio) {
+                LOG_INF("%s: has audio encoder\n", __func__);
+            }
+        }
+
+        // tensors
+        {
+            for (int i = 0; i < n_tensors; ++i) {
+                const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
+                const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i);
+                enum ggml_type type = gguf_get_tensor_type(ctx_gguf.get(), i);
+                ggml_tensor * cur = ggml_get_tensor(meta, name);
+                size_t tensor_size = ggml_nbytes(cur);
+                model_size += tensor_size;
+                LOG_DBG("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
+                    __func__, i, ggml_n_dims(cur), cur->name, tensor_size, offset, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3], ggml_type_name(type));
+            }
+        }
+    }
+
+    void load_hparams(clip_model & model, clip_modality modality) {
+        auto & hparams = model.hparams;
+        std::string log_ffn_op; // for logging
+
+        // sanity check
+        if (modality == CLIP_MODALITY_VISION) {
+            GGML_ASSERT(has_vision);
+        } else if (modality == CLIP_MODALITY_AUDIO) {
+            GGML_ASSERT(has_audio);
+        }
+        model.modality = modality;
+
+
+        // projector type
+        std::string proj_type;
+        {
+            get_string(KEY_PROJ_TYPE, proj_type, false);
+            if (!proj_type.empty()) {
+                model.proj_type = clip_projector_type_from_string(proj_type);
+            }
+            if (model.proj_type == PROJECTOR_TYPE_UNKNOWN) {
+                throw std::runtime_error(string_format("%s: unknown projector type: %s\n", __func__, proj_type.c_str()));
+            }
+
+            // correct arch for multimodal models
+            if (model.proj_type == PROJECTOR_TYPE_QWEN25O) {
+                model.proj_type = modality == CLIP_MODALITY_VISION
+                                    ? PROJECTOR_TYPE_QWEN25VL
+                                    : PROJECTOR_TYPE_QWEN2A;
+            }
+        }
+
+        const bool is_vision = model.modality == CLIP_MODALITY_VISION;
+        const bool is_audio  = model.modality == CLIP_MODALITY_AUDIO;
+
+        // other hparams
+        {
+            const char * prefix = is_vision ? "vision" : "audio";
+            get_u32(string_format(KEY_N_EMBD,         prefix), hparams.n_embd);
+            get_u32(string_format(KEY_N_HEAD,         prefix), hparams.n_head);
+            get_u32(string_format(KEY_N_FF,           prefix), hparams.n_ff);
+            get_u32(string_format(KEY_N_BLOCK,        prefix), hparams.n_layer);
+            get_u32(string_format(KEY_PROJ_DIM,       prefix), hparams.projection_dim);
+            get_f32(string_format(KEY_LAYER_NORM_EPS, prefix), hparams.eps);
+
+            if (is_vision) {
+                get_u32(KEY_IMAGE_SIZE, hparams.image_size);
+                get_u32(KEY_PATCH_SIZE, hparams.patch_size);
+                get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
+                get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
+
+            } else if (is_audio) {
+                get_u32(KEY_A_NUM_MEL_BINS, hparams.n_mel_bins);
+
+            } else {
+                GGML_ASSERT(false && "unknown modality");
+            }
+
+            // for pinpoints, we need to convert it into a list of resolution candidates
+            {
+                std::vector pinpoints;
+                get_arr_int(KEY_IMAGE_GRID_PINPOINTS, pinpoints, false);
+                if (!pinpoints.empty()) {
+                    for (size_t i = 0; i < pinpoints.size(); i += 2) {
+                        hparams.image_res_candidates.push_back({
+                            pinpoints[i],
+                            pinpoints[i+1],
+                        });
+                    }
+                }
+            }
+
+            // default warmup value
+            hparams.warmup_image_size = hparams.image_size;
+
+            hparams.has_llava_projector = model.proj_type == PROJECTOR_TYPE_MLP
+                                       || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+                                       || model.proj_type == PROJECTOR_TYPE_LDP
+                                       || model.proj_type == PROJECTOR_TYPE_LDPV2;
+
+            {
+                bool use_gelu = false;
+                bool use_silu = false;
+                get_bool(KEY_USE_GELU, use_gelu, false);
+                get_bool(KEY_USE_SILU, use_silu, false);
+                if (use_gelu && use_silu) {
+                    throw std::runtime_error(string_format("%s: both use_gelu and use_silu are set to true\n", __func__));
+                }
+                if (use_gelu) {
+                    hparams.ffn_op = FFN_GELU;
+                    log_ffn_op = "gelu";
+                } else if (use_silu) {
+                    hparams.ffn_op = FFN_SILU;
+                    log_ffn_op = "silu";
+                } else {
+                    hparams.ffn_op = FFN_GELU_QUICK;
+                    log_ffn_op = "gelu_quick";
+                }
+            }
+
+            {
+                std::string mm_patch_merge_type;
+                get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
+                if (mm_patch_merge_type == "spatial_unpad") {
+                    hparams.mm_patch_merge_type = PATCH_MERGE_SPATIAL_UNPAD;
+                }
+            }
+
+            if (is_vision) {
+                int idx_mean = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_MEAN);
+                int idx_std  = gguf_find_key(ctx_gguf.get(), KEY_IMAGE_STD);
+                GGML_ASSERT(idx_mean >= 0 && "image_mean not found");
+                GGML_ASSERT(idx_std >= 0  && "image_std not found");
+                const float * mean_data = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_mean);
+                const float * std_data  = (const float *) gguf_get_arr_data(ctx_gguf.get(), idx_std);
+                for (int i = 0; i < 3; ++i) {
+                    hparams.image_mean[i] = mean_data[i];
+                    hparams.image_std[i]  = std_data[i];
+                }
+            }
+
+            // Load the vision feature layer indices if they are explicitly provided;
+            // if multiple vision feature layers are present, the values will be concatenated
+            // to form the final visual features.
+            // NOTE: gguf conversions should standardize the values of the vision feature layer to
+            // be non-negative, since we use -1 to mark values as unset here.
+            std::vector vision_feature_layer;
+            get_arr_int(KEY_FEATURE_LAYER, vision_feature_layer, false);
+            // convert std::vector to std::unordered_set
+            for (auto & layer : vision_feature_layer) {
+                hparams.vision_feature_layer.insert(layer);
+            }
+
+            // model-specific params
+            switch (model.proj_type) {
+                case PROJECTOR_TYPE_MINICPMV:
+                    {
+                        if (hparams.minicpmv_version == 0) {
+                            hparams.minicpmv_version = 2; // default to 2 if not set
+                        }
+                    } break;
+                case PROJECTOR_TYPE_IDEFICS3:
+                case PROJECTOR_TYPE_INTERNVL:
+                    {
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
+                case PROJECTOR_TYPE_PIXTRAL:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                        hparams.warmup_image_size = hparams.patch_size * 8;
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                    } break;
+                case PROJECTOR_TYPE_GEMMA3:
+                    {
+                        // default value (used by all model sizes in gemma 3 family)
+                        // number of patches for each **side** is reduced by a factor of 4
+                        hparams.proj_scale_factor = 4;
+                        // test model (tinygemma3) has a different value, we optionally read it
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
+                case PROJECTOR_TYPE_QWEN2VL:
+                    {
+                        // max image size = sqrt(max_pixels) = 3584
+                        // ref: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json
+                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
+                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
+                        hparams.image_size = 1024;
+                        hparams.warmup_image_size = hparams.patch_size * 8;
+                    } break;
+                case PROJECTOR_TYPE_QWEN25VL:
+                    {
+                        // max image size = sqrt(max_pixels)
+                        // https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
+                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
+                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
+                        hparams.image_size = 1024;
+                        hparams.warmup_image_size = hparams.patch_size * 8;
+                        get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
+                    } break;
+                case PROJECTOR_TYPE_LLAMA4:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+                        set_llava_uhd_res_candidates(model, 3);
+                    } break;
+                case PROJECTOR_TYPE_ULTRAVOX:
+                case PROJECTOR_TYPE_QWEN2A:
+                    {
+                        bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX;
+                        get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
+                        if (hparams.n_mel_bins != 128) {
+                            throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
+                        }
+                        hparams.ffn_op = FFN_GELU_ERF;
+                        log_ffn_op = "gelu_erf"; // temporary solution for logging
+                    } break;
+                default:
+                    break;
+            }
+
+            LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
+            LOG_INF("%s: n_embd:             %d\n", __func__, hparams.n_embd);
+            LOG_INF("%s: n_head:             %d\n", __func__, hparams.n_head);
+            LOG_INF("%s: n_ff:               %d\n", __func__, hparams.n_ff);
+            LOG_INF("%s: n_layer:            %d\n", __func__, hparams.n_layer);
+            LOG_INF("%s: ffn_op:             %s\n", __func__, log_ffn_op.c_str());
+            LOG_INF("%s: projection_dim:     %d\n", __func__, hparams.projection_dim);
+            if (is_vision) {
+                LOG_INF("\n--- vision hparams ---\n");
+                LOG_INF("%s: image_size:         %d\n", __func__, hparams.image_size);
+                LOG_INF("%s: patch_size:         %d\n", __func__, hparams.patch_size);
+                LOG_INF("%s: has_llava_proj:     %d\n", __func__, hparams.has_llava_projector);
+                LOG_INF("%s: minicpmv_version:   %d\n", __func__, hparams.minicpmv_version);
+                LOG_INF("%s: proj_scale_factor:  %d\n", __func__, hparams.proj_scale_factor);
+                LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
+            } else if (is_audio) {
+                LOG_INF("\n--- audio hparams ---\n");
+                LOG_INF("%s: n_mel_bins:         %d\n", __func__, hparams.n_mel_bins);
+                LOG_INF("%s: proj_stack_factor:  %d\n", __func__, hparams.proj_stack_factor);
+            }
+            LOG_INF("\n");
+            LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
+            LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
+        }
+    }
+
+    void load_tensors(clip_ctx & ctx_clip) {
+        auto & model = ctx_clip.model;
+        auto & hparams = model.hparams;
+        std::map tensor_offset;
+        std::vector tensors_to_load;
+
+        // TODO @ngxson : support both audio and video in the future
+        const char * prefix = model.modality == CLIP_MODALITY_AUDIO ? "a" : "v";
+
+        // get offsets
+        for (int64_t i = 0; i < gguf_get_n_tensors(ctx_gguf.get()); ++i) {
+            const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
+            tensor_offset[name] = gguf_get_data_offset(ctx_gguf.get()) + gguf_get_tensor_offset(ctx_gguf.get(), i);
+        }
+
+        // create data context
+        struct ggml_init_params params = {
+            /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc =*/ true,
+        };
+        ctx_clip.ctx_data.reset(ggml_init(params));
+        if (!ctx_clip.ctx_data) {
+            throw std::runtime_error(string_format("%s: failed to init ggml context\n", __func__));
+        }
+
+        // helper function
+        auto get_tensor = [&](const std::string & name, bool required = true) {
+            ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
+            if (!cur && required) {
+                throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
+            }
+            if (cur) {
+                tensors_to_load.push_back(cur);
+                // add tensors to context
+                ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
+                ggml_set_name(data_tensor, cur->name);
+                cur = data_tensor;
+            }
+            return cur;
+        };
+
+        model.class_embedding = get_tensor(TN_CLASS_EMBD, false);
+
+        model.pre_ln_w = get_tensor(string_format(TN_LN_PRE, prefix, "weight"), false);
+        model.pre_ln_b = get_tensor(string_format(TN_LN_PRE, prefix, "bias"),   false);
+
+        model.post_ln_w = get_tensor(string_format(TN_LN_POST, prefix, "weight"), false);
+        model.post_ln_b = get_tensor(string_format(TN_LN_POST, prefix, "bias"),   false);
+
+        model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
+        model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD,   false);
+        model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
+
+        model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, prefix), false);
+
+        // layers
+        model.layers.resize(hparams.n_layer);
+        for (int il = 0; il < hparams.n_layer; ++il) {
+            auto & layer = model.layers[il];
+            layer.k_w    = get_tensor(string_format(TN_ATTN_K,      prefix, il, "weight"));
+            layer.q_w    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "weight"));
+            layer.v_w    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "weight"));
+            layer.o_w    = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "weight"));
+            layer.k_norm = get_tensor(string_format(TN_ATTN_K_NORM, prefix, il, "weight"), false);
+            layer.q_norm = get_tensor(string_format(TN_ATTN_Q_NORM, prefix, il, "weight"), false);
+            layer.ln_1_w = get_tensor(string_format(TN_LN_1,        prefix, il, "weight"), false);
+            layer.ln_2_w = get_tensor(string_format(TN_LN_2,        prefix, il, "weight"), false);
+            layer.ls_1_w = get_tensor(string_format(TN_LS_1,        prefix, il, "weight"), false); // no bias
+            layer.ls_2_w = get_tensor(string_format(TN_LS_2,        prefix, il, "weight"), false); // no bias
+
+            layer.k_b    = get_tensor(string_format(TN_ATTN_K,      prefix, il, "bias"), false);
+            layer.q_b    = get_tensor(string_format(TN_ATTN_Q,      prefix, il, "bias"), false);
+            layer.v_b    = get_tensor(string_format(TN_ATTN_V,      prefix, il, "bias"), false);
+            layer.o_b    = get_tensor(string_format(TN_ATTN_OUTPUT, prefix, il, "bias"), false);
+            layer.ln_1_b = get_tensor(string_format(TN_LN_1,        prefix, il, "bias"), false);
+            layer.ln_2_b = get_tensor(string_format(TN_LN_2,        prefix, il, "bias"), false);
+
+            // ffn
+            layer.ff_up_w   = get_tensor(string_format(TN_FFN_UP,   prefix, il, "weight"));
+            layer.ff_up_b   = get_tensor(string_format(TN_FFN_UP,   prefix, il, "bias"),   false);
+            layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, prefix, il, "weight"), false);
+            layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, prefix, il, "bias"),   false);
+            layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "weight"));
+            layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, prefix, il, "bias"),   false);
+
+            // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
+            // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
+            if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+                // swap up and down weights
+                ggml_tensor * tmp = layer.ff_up_w;
+                layer.ff_up_w = layer.ff_down_w;
+                layer.ff_down_w = tmp;
+                // swap up and down biases
+                tmp = layer.ff_up_b;
+                layer.ff_up_b = layer.ff_down_b;
+                layer.ff_down_b = tmp;
+            }
+        }
+
+        switch (model.proj_type) {
+            case PROJECTOR_TYPE_MLP:
+            case PROJECTOR_TYPE_MLP_NORM:
+                {
+                    // LLaVA projection
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"), false);
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"), false);
+                    // Yi-type llava
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"), false);
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    // missing in Yi-type llava
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"), false);
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    // Yi-type llava
+                    model.mm_3_w = get_tensor(string_format(TN_LLAVA_PROJ, 3, "weight"), false);
+                    model.mm_3_b = get_tensor(string_format(TN_LLAVA_PROJ, 3, "bias"), false);
+                    model.mm_4_w = get_tensor(string_format(TN_LLAVA_PROJ, 4, "weight"), false);
+                    model.mm_4_b = get_tensor(string_format(TN_LLAVA_PROJ, 4, "bias"), false);
+                    if (model.mm_3_w) {
+                        // TODO: this is a hack to support Yi-type llava
+                        model.proj_type = PROJECTOR_TYPE_MLP_NORM;
+                    }
+                    model.image_newline = get_tensor(TN_IMAGE_NEWLINE, false);
+                } break;
+            case PROJECTOR_TYPE_LDP:
+                {
+                    // MobileVLM projection
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_model_mlp_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    model.mm_model_mlp_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    model.mm_model_mlp_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                    model.mm_model_block_1_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "0.weight"));
+                    model.mm_model_block_1_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.weight"));
+                    model.mm_model_block_1_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 0, "1.bias"));
+                    model.mm_model_block_1_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.weight"));
+                    model.mm_model_block_1_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc1.bias"));
+                    model.mm_model_block_1_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.weight"));
+                    model.mm_model_block_1_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 1, "fc2.bias"));
+                    model.mm_model_block_1_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "0.weight"));
+                    model.mm_model_block_1_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.weight"));
+                    model.mm_model_block_1_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 1, 2, "1.bias"));
+                    model.mm_model_block_2_block_0_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "0.weight"));
+                    model.mm_model_block_2_block_0_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.weight"));
+                    model.mm_model_block_2_block_0_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 0, "1.bias"));
+                    model.mm_model_block_2_block_1_fc1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.weight"));
+                    model.mm_model_block_2_block_1_fc1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc1.bias"));
+                    model.mm_model_block_2_block_1_fc2_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.weight"));
+                    model.mm_model_block_2_block_1_fc2_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 1, "fc2.bias"));
+                    model.mm_model_block_2_block_2_0_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "0.weight"));
+                    model.mm_model_block_2_block_2_1_w = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.weight"));
+                    model.mm_model_block_2_block_2_1_b = get_tensor(string_format(TN_MVLM_PROJ_BLOCK, 2, 2, "1.bias"));
+                } break;
+            case PROJECTOR_TYPE_LDPV2:
+                {
+                    // MobilVLM_V2 projection
+                    model.mm_model_mlp_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    model.mm_model_mlp_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                    model.mm_model_mlp_2_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "bias"));
+                    model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
+                    model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
+                } break;
+            case PROJECTOR_TYPE_MINICPMV:
+                {
+                    // model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
+                    model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
+                    model.mm_model_query = get_tensor(TN_MINICPMV_QUERY);
+                    model.mm_model_proj = get_tensor(TN_MINICPMV_PROJ);
+                    model.mm_model_kv_proj = get_tensor(TN_MINICPMV_KV_PROJ);
+                    model.mm_model_attn_q_w = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "weight"));
+                    model.mm_model_attn_k_w = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "weight"));
+                    model.mm_model_attn_v_w = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "weight"));
+                    model.mm_model_attn_q_b = get_tensor(string_format(TN_MINICPMV_ATTN, "q", "bias"));
+                    model.mm_model_attn_k_b = get_tensor(string_format(TN_MINICPMV_ATTN, "k", "bias"));
+                    model.mm_model_attn_v_b = get_tensor(string_format(TN_MINICPMV_ATTN, "v", "bias"));
+                    model.mm_model_attn_o_w = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "weight"));
+                    model.mm_model_attn_o_b = get_tensor(string_format(TN_MINICPMV_ATTN, "out", "bias"));
+                    model.mm_model_ln_q_w = get_tensor(string_format(TN_MINICPMV_LN, "q", "weight"));
+                    model.mm_model_ln_q_b = get_tensor(string_format(TN_MINICPMV_LN, "q", "bias"));
+                    model.mm_model_ln_kv_w = get_tensor(string_format(TN_MINICPMV_LN, "kv", "weight"));
+                    model.mm_model_ln_kv_b = get_tensor(string_format(TN_MINICPMV_LN, "kv", "bias"));
+                    model.mm_model_ln_post_w = get_tensor(string_format(TN_MINICPMV_LN, "post", "weight"));
+                    model.mm_model_ln_post_b = get_tensor(string_format(TN_MINICPMV_LN, "post", "bias"));
+                } break;
+            case PROJECTOR_TYPE_GLM_EDGE:
+                {
+                    model.mm_model_adapter_conv_w = get_tensor(string_format(TN_GLM_ADAPER_CONV, "weight"));
+                    model.mm_model_adapter_conv_b = get_tensor(string_format(TN_GLM_ADAPER_CONV, "bias"));
+                    model.mm_model_mlp_0_w = get_tensor(string_format(TN_GLM_ADAPTER_LINEAR, "weight"));
+                    model.mm_model_ln_q_w = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "weight"));
+                    model.mm_model_ln_q_b = get_tensor(string_format(TN_GLM_ADAPTER_NORM_1, "bias"));
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_GLM_ADAPTER_D_H_2_4H, "weight"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE, "weight"));
+                    model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H, "weight"));
+                    model.mm_glm_tok_boi = get_tensor(string_format(TN_TOK_GLM_BOI, "weight"));
+                    model.mm_glm_tok_eoi = get_tensor(string_format(TN_TOK_GLM_EOI, "weight"));
+                } break;
+            case PROJECTOR_TYPE_QWEN2VL:
+            case PROJECTOR_TYPE_QWEN25VL:
+                {
+                    model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                } break;
+            case PROJECTOR_TYPE_GEMMA3:
+                {
+                    model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
+                    model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
+                } break;
+            case PROJECTOR_TYPE_IDEFICS3:
+                {
+                    model.projection = get_tensor(TN_MM_PROJECTOR);
+                } break;
+            case PROJECTOR_TYPE_PIXTRAL:
+                {
+                    model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
+                    model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
+                    // [IMG_BREAK] token embedding
+                    model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    // for mistral small 3.1
+                    model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
+                    model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
+                } break;
+            case PROJECTOR_TYPE_ULTRAVOX:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
+                    model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
+                    model.mm_norm_pre_w = get_tensor(string_format(TN_MM_NORM_PRE, "weight"));
+                    model.mm_norm_mid_w = get_tensor(string_format(TN_MM_NORM_MID, "weight"));
+                } break;
+            case PROJECTOR_TYPE_QWEN2A:
+                {
+                    model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
+                    model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
+                    model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
+                    model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
+                    model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
+                    model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
+                } break;
+            case PROJECTOR_TYPE_INTERNVL:
+                {
+                    model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                } break;
+            case PROJECTOR_TYPE_LLAMA4:
+                {
+                    model.mm_model_proj    = get_tensor(TN_MM_PROJECTOR);
+                    model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                } break;
+            default:
+                GGML_ASSERT(false && "unknown projector type");
+        }
+
+        // load data
+        {
+            std::vector read_buf;
+
+            auto fin = std::ifstream(fname, std::ios::binary);
+            if (!fin) {
+                throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str()));
+            }
+
+            // alloc memory and offload data
+            ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend);
+            ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
+            ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+            for (auto & t : tensors_to_load) {
+                ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
+                const size_t offset = tensor_offset[t->name];
+                fin.seekg(offset, std::ios::beg);
+                if (!fin) {
+                    throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name));
+                }
+                size_t num_bytes = ggml_nbytes(cur);
+                if (ggml_backend_buft_is_host(buft)) {
+                    // for the CPU and Metal backend, we can read directly into the tensor
+                    fin.read(reinterpret_cast(cur->data), num_bytes);
+                } else {
+                    // read into a temporary buffer first, then copy to device memory
+                    read_buf.resize(num_bytes);
+                    fin.read(reinterpret_cast(read_buf.data()), num_bytes);
+                    ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+                }
+            }
+            fin.close();
+
+            LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str());
+        }
+    }
+
+    void alloc_compute_meta(clip_ctx & ctx_clip) {
+        const auto & hparams = ctx_clip.model.hparams;
+        ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
+
+        // create a fake batch
+        clip_image_f32_batch batch;
+        clip_image_f32_ptr img(clip_image_f32_init());
+        if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
+            img->nx = hparams.warmup_image_size;
+            img->ny = hparams.warmup_image_size;
+        } else {
+            img->nx = hparams.warmup_audio_size;
+            img->ny = hparams.n_mel_bins;
+        }
+        batch.entries.push_back(std::move(img));
+
+        ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
+        ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
+
+        for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
+            ggml_backend_t backend = ctx_clip.backend_ptrs[i];
+            ggml_backend_buffer_type_t buft = ctx_clip.backend_buft[i];
+            size_t size = ggml_backend_sched_get_buffer_size(ctx_clip.sched.get(), backend);
+            if (size > 1) {
+                LOG_INF("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                        ggml_backend_buft_name(buft),
+                        size / 1024.0 / 1024.0);
+            }
+        }
+    }
+
+    void get_bool(const std::string & key, bool & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_bool(ctx_gguf.get(), i);
+    }
+
+    void get_i32(const std::string & key, int & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_i32(ctx_gguf.get(), i);
+    }
+
+    void get_u32(const std::string & key, int & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_u32(ctx_gguf.get(), i);
+    }
+
+    void get_f32(const std::string & key, float & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = gguf_get_val_f32(ctx_gguf.get(), i);
+    }
+
+    void get_string(const std::string & key, std::string & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        output = std::string(gguf_get_val_str(ctx_gguf.get(), i));
+    }
+
+    void get_arr_int(const std::string & key, std::vector & output, bool required = true) {
+        const int i = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (i < 0) {
+            if (required) throw std::runtime_error("Key not found: " + key);
+            return;
+        }
+        int n = gguf_get_arr_n(ctx_gguf.get(), i);
+        output.resize(n);
+        const int32_t * values = (const int32_t *)gguf_get_arr_data(ctx_gguf.get(), i);
+        for (int i = 0; i < n; ++i) {
+            output[i] = values[i];
+        }
+    }
+
+    void set_llava_uhd_res_candidates(clip_model & model, const int max_patches_per_side) {
+        auto & hparams = model.hparams;
+        for (int x = 1; x <= max_patches_per_side; x++) {
+            for (int y = 1; y <= max_patches_per_side; y++) {
+                if (x == 1 && y == 1) {
+                    continue; // skip the first point
+                }
+                hparams.image_res_candidates.push_back(clip_image_size{
+                    x*hparams.image_size,
+                    y*hparams.image_size,
+                });
+            }
+        }
+    }
+};
+
+struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
+    g_logger_state.verbosity_thold = ctx_params.verbosity;
+    clip_ctx * ctx_vision = nullptr;
+    clip_ctx * ctx_audio = nullptr;
+
+    try {
+        clip_model_loader loader(fname);
+
+        if (loader.has_vision) {
+            ctx_vision = new clip_ctx(ctx_params);
+            loader.load_hparams(ctx_vision->model, CLIP_MODALITY_VISION);
+            loader.load_tensors(*ctx_vision);
+            loader.alloc_compute_meta(*ctx_vision);
+        }
+
+        if (loader.has_audio) {
+            ctx_audio = new clip_ctx(ctx_params);
+            loader.load_hparams(ctx_audio->model, CLIP_MODALITY_AUDIO);
+            loader.load_tensors(*ctx_audio);
+            loader.alloc_compute_meta(*ctx_audio);
+        }
+
+    } catch (const std::exception & e) {
+        LOG_ERR("%s: failed to load model '%s': %s\n", __func__, fname, e.what());
+        if (ctx_vision) {
+            delete ctx_vision;
+        }
+        if (ctx_audio) {
+            delete ctx_audio;
+        }
+        return {nullptr, nullptr};
+    }
+
+    return {ctx_vision, ctx_audio};
+}
+
+struct clip_image_size * clip_image_size_init() {
+    struct clip_image_size * load_image_size = new struct clip_image_size();
+    load_image_size->width = 448;
+    load_image_size->height = 448;
+    return load_image_size;
+}
+
+struct clip_image_u8 * clip_image_u8_init() {
+    return new clip_image_u8();
+}
+
+struct clip_image_f32 * clip_image_f32_init() {
+    return new clip_image_f32();
+}
+
+struct clip_image_f32_batch * clip_image_f32_batch_init() {
+    return new clip_image_f32_batch();
+}
+
+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny) {
+    if (nx) *nx = img->nx;
+    if (ny) *ny = img->ny;
+    return img->buf.data();
+}
+
+void clip_image_size_free(struct clip_image_size * load_image_size) {
+    if (load_image_size == nullptr) {
+        return;
+    }
+    delete load_image_size;
+}
+void clip_image_u8_free(struct clip_image_u8  * img) { if (img) delete img; }
+void clip_image_f32_free(struct clip_image_f32 * img) { if (img) delete img; }
+void clip_image_u8_batch_free(struct clip_image_u8_batch * batch) { if (batch) delete batch; }
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { if (batch) delete batch; }
+
+size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch) {
+    return batch->entries.size();
+}
+
+size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
+    }
+    return batch->entries[idx]->nx;
+}
+
+size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return 0;
+    }
+    return batch->entries[idx]->ny;
+}
+
+clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx) {
+    if (idx < 0 || idx >= (int)batch->entries.size()) {
+        LOG_ERR("%s: invalid index %d\n", __func__, idx);
+        return nullptr;
+    }
+    return batch->entries[idx].get();
+}
+
+void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
+    img->nx = nx;
+    img->ny = ny;
+    img->buf.resize(3 * nx * ny);
+    memcpy(img->buf.data(), rgb_pixels, img->buf.size());
+}
+
+// Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
+static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
+    dst.nx = src.nx;
+    dst.ny = src.ny;
+    dst.buf.resize(src.buf.size());
+
+    // TODO @ngxson : seems like this could be done more efficiently on cgraph
+    for (size_t i = 0; i < src.buf.size(); ++i) {
+        int c = i % 3; // rgb
+        dst.buf[i] = (static_cast(src.buf[i]) / 255.0f - mean[c]) / std[c];
+    }
+}
+
+// set of tools to manupulate images
+// in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
+struct image_manipulation {
+    // Bilinear resize function
+    static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+        dst.nx = target_width;
+        dst.ny = target_height;
+        dst.buf.resize(3 * target_width * target_height);
+
+        float x_ratio = static_cast(src.nx - 1) / target_width;
+        float y_ratio = static_cast(src.ny - 1) / target_height;
+
+        for (int y = 0; y < target_height; y++) {
+            for (int x = 0; x < target_width; x++) {
+                float px = x_ratio * x;
+                float py = y_ratio * y;
+                int x_floor = static_cast(px);
+                int y_floor = static_cast(py);
+                float x_lerp = px - x_floor;
+                float y_lerp = py - y_floor;
+
+                for (int c = 0; c < 3; c++) {
+                    float top = lerp(
+                        static_cast(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
+                        static_cast(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
+                        x_lerp
+                    );
+                    float bottom = lerp(
+                        static_cast(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
+                        static_cast(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
+                        x_lerp
+                    );
+                    dst.buf[3 * (y * target_width + x) + c] = static_cast(lerp(top, bottom, y_lerp));
+                }
+            }
+        }
+    }
+
+    // Bicubic resize function
+    // part of image will be cropped if the aspect ratio is different
+    static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+        const int nx = img.nx;
+        const int ny = img.ny;
+
+        dst.nx = target_width;
+        dst.ny = target_height;
+        dst.buf.resize(3 * target_width * target_height);
+
+        float Cc;
+        float C[5];
+        float d0, d2, d3, a0, a1, a2, a3;
+        int i, j, k, jj;
+        int x, y;
+        float dx, dy;
+        float tx, ty;
+
+        tx = (float)nx / (float)target_width;
+        ty = (float)ny / (float)target_height;
+
+        // Bicubic interpolation; adapted from ViT.cpp, inspired from :
+        //    -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
+        //    -> https://en.wikipedia.org/wiki/Bicubic_interpolation
+
+        for (i = 0; i < target_height; i++) {
+            for (j = 0; j < target_width; j++) {
+                x = (int)(tx * j);
+                y = (int)(ty * i);
+
+                dx = tx * j - x;
+                dy = ty * i - y;
+
+                for (k = 0; k < 3; k++) {
+                    for (jj = 0; jj <= 3; jj++) {
+                        d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+
+                        a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                        a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                        a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+
+                        C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
+
+                        d0 = C[0] - C[1];
+                        d2 = C[2] - C[1];
+                        d3 = C[3] - C[1];
+                        a0 = C[1];
+                        a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                        a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                        a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+                        Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
+
+                        const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
+                        dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
+                    }
+                }
+            }
+        }
+
+        return true;
+    }
+
+    // llava-1.6 type of resize_and_pad
+    // if the ratio is not 1:1, padding with pad_color will be applied
+    // pad_color is single channel, default is 0 (black)
+    static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array pad_color = {0, 0, 0}) {
+        int target_width  = target_resolution.width;
+        int target_height = target_resolution.height;
+
+        float scale_w = static_cast(target_width) / image.nx;
+        float scale_h = static_cast(target_height) / image.ny;
+
+        int new_width, new_height;
+
+        if (scale_w < scale_h) {
+            new_width  = target_width;
+            new_height = std::min(static_cast(std::ceil(image.ny * scale_w)), target_height);
+        } else {
+            new_height = target_height;
+            new_width  = std::min(static_cast(std::ceil(image.nx * scale_h)), target_width);
+        }
+
+        clip_image_u8 resized_image;
+        bicubic_resize(image, resized_image, new_width, new_height);
+
+        clip_image_u8 padded_image;
+        padded_image.nx = target_width;
+        padded_image.ny = target_height;
+        padded_image.buf.resize(3 * target_width * target_height);
+
+        // Fill the padded image with the fill color
+        for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
+            padded_image.buf[i]     = pad_color[0];
+            padded_image.buf[i + 1] = pad_color[1];
+            padded_image.buf[i + 2] = pad_color[2];
+        }
+
+        // Calculate padding offsets
+        int pad_x = (target_width  - new_width)  / 2;
+        int pad_y = (target_height - new_height) / 2;
+
+        // Copy the resized image into the center of the padded buffer
+        for (int y = 0; y < new_height; ++y) {
+            for (int x = 0; x < new_width; ++x) {
+                for (int c = 0; c < 3; ++c) {
+                    padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+                }
+            }
+        }
+        dst = std::move(padded_image);
+    }
+
+    static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
+        dst.nx = w;
+        dst.ny = h;
+        dst.buf.resize(3 * w * h);
+
+        for (int i = 0; i < h; ++i) {
+            for (int j = 0; j < w; ++j) {
+                int src_idx = 3 * ((y + i)*image.nx + (x + j));
+                int dst_idx = 3 * (i*w + j);
+                dst.buf[dst_idx]     = image.buf[src_idx];
+                dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
+                dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
+            }
+        }
+    }
+
+    // calculate the size of the **resized** image, while preserving the aspect ratio
+    // the calculated size will be aligned to the nearest multiple of align_size
+    // if H or W size is larger than max_dimension, it will be resized to max_dimension
+    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
+        if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
+            return {0, 0};
+        }
+
+        float scale = std::min(1.0f, std::min(static_cast(max_dimension) / inp_size.width,
+                                              static_cast(max_dimension) / inp_size.height));
+
+        float target_width_f  = static_cast(inp_size.width)  * scale;
+        float target_height_f = static_cast(inp_size.height) * scale;
+
+        int aligned_width  = CLIP_ALIGN((int)target_width_f,  align_size);
+        int aligned_height = CLIP_ALIGN((int)target_height_f, align_size);
+
+        return {aligned_width, aligned_height};
+    }
+
+private:
+    static inline int clip(int x, int lower, int upper) {
+        return std::max(lower, std::min(x, upper));
+    }
+
+    // Linear interpolation between two points
+    static inline float lerp(float s, float e, float t) {
+        return s + (e - s) * t;
+    }
+};
+
+/**
+ * implementation of LLaVA-UHD:
+ *  - https://arxiv.org/pdf/2403.11703
+ *  - https://github.com/thunlp/LLaVA-UHD
+ *  - https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
+ *
+ * overview:
+ *   - an image always have a single overview (downscaled image)
+ *   - an image can have 0 or multiple slices, depending on the image size
+ *   - each slice can then be considered as a separate image
+ *
+ * for example:
+ *
+ * [overview] --> [slice 1] --> [slice 2]
+ *           |                |
+ *           +--> [slice 3] --> [slice 4]
+ */
+struct llava_uhd {
+    struct slice_coordinates {
+        int x;
+        int y;
+        clip_image_size size;
+    };
+
+    struct slice_instructions {
+        clip_image_size overview_size; // size of downscaled image
+        clip_image_size refined_size;  // size of image right before slicing (must be multiple of slice size)
+        clip_image_size grid_size;     // grid_size.width * grid_size.height = number of slices
+        std::vector slices;
+        bool padding_refined = false;  // if true, refine image will be padded to the grid size (e.g. llava-1.6)
+    };
+
+    static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
+        slice_instructions res;
+        const int patch_size      = clip_get_patch_size(ctx);
+        const int slice_size      = clip_get_image_size(ctx);
+        const int original_width  = original_size.width;
+        const int original_height = original_size.height;
+
+        const bool has_slices    = original_size.width > slice_size || original_size.height > slice_size;
+        const bool has_pinpoints = !ctx->model.hparams.image_res_candidates.empty();
+
+        if (!has_slices) {
+            // skip slicing logic
+            res.overview_size = clip_image_size{slice_size, slice_size};
+            res.refined_size  = clip_image_size{0, 0};
+            res.grid_size     = clip_image_size{0, 0};
+
+            return res;
+        }
+
+        if (has_pinpoints) {
+            // has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
+            auto refine_size = llava_uhd::select_best_resolution(
+                original_size,
+                ctx->model.hparams.image_res_candidates);
+            res.overview_size   = clip_image_size{slice_size, slice_size};
+            res.refined_size    = refine_size;
+            res.grid_size       = clip_image_size{0, 0};
+            res.padding_refined = true;
+
+            LOG_DBG("%s: using pinpoints for slicing\n", __func__);
+            LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d\n",
+                    __func__, original_width, original_height,
+                    res.overview_size.width, res.overview_size.height,
+                    res.refined_size.width,  res.refined_size.height);
+
+            for (int y = 0; y < refine_size.height; y += slice_size) {
+                for (int x = 0; x < refine_size.width; x += slice_size) {
+                    slice_coordinates slice;
+                    slice.x = x;
+                    slice.y = y;
+                    slice.size.width  = std::min(slice_size, refine_size.width  - x);
+                    slice.size.height = std::min(slice_size, refine_size.height - y);
+                    res.slices.push_back(slice);
+                    LOG_DBG("%s: slice %d: x=%d, y=%d, size=%dx%d\n",
+                            __func__, (int)res.slices.size() - 1,
+                            slice.x, slice.y, slice.size.width, slice.size.height);
+                }
+            }
+
+            res.grid_size.height = refine_size.height / slice_size;
+            res.grid_size.width  = refine_size.width  / slice_size;
+            LOG_DBG("%s: grid size: %d x %d\n", __func__, res.grid_size.width, res.grid_size.height);
+
+            return res;
+        }
+
+        // no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
+
+        auto best_size    = get_best_resize(original_size, slice_size, patch_size, !has_slices);
+        res.overview_size = best_size;
+
+        {
+            const int max_slice_nums = 9; // TODO: this is only used by minicpmv, maybe remove it
+            const float log_ratio = log((float)original_width / original_height);
+            const float ratio = (float)original_width * original_height / (slice_size * slice_size);
+            const int multiple = fmin(ceil(ratio), max_slice_nums);
+
+            auto best_grid   = get_best_grid(max_slice_nums, multiple, log_ratio);
+            auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true);
+            res.grid_size    = best_grid;
+            res.refined_size = refine_size;
+
+            LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d, grid size: %d x %d\n",
+                    __func__, original_width, original_height,
+                    res.overview_size.width, res.overview_size.height,
+                    res.refined_size.width, res.refined_size.height,
+                    res.grid_size.width, res.grid_size.height);
+
+            int width  = refine_size.width;
+            int height = refine_size.height;
+            int grid_x = int(width  / best_grid.width);
+            int grid_y = int(height / best_grid.height);
+            for (int patches_y = 0,                    ic = 0;
+                    patches_y < refine_size.height && ic < best_grid.height;
+                    patches_y += grid_y,              ic += 1) {
+                for (int patches_x = 0,                   jc = 0;
+                        patches_x < refine_size.width && jc < best_grid.width;
+                        patches_x += grid_x,             jc += 1) {
+                    slice_coordinates slice;
+                    slice.x = patches_x;
+                    slice.y = patches_y;
+                    slice.size.width  = grid_x;
+                    slice.size.height = grid_y;
+                    res.slices.push_back(slice);
+                    LOG_DBG("%s: slice %d: x=%d, y=%d, size=%dx%d\n",
+                            __func__, (int)res.slices.size() - 1,
+                            slice.x, slice.y, slice.size.width, slice.size.height);
+                }
+            }
+        }
+
+        return res;
+    }
+
+    static std::vector slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
+        std::vector output;
+
+        // resize to overview size
+        clip_image_u8_ptr resized_img(clip_image_u8_init());
+        image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height);
+        output.push_back(std::move(resized_img));
+        if (inst.slices.empty()) {
+            // no slices, just return the resized image
+            return output;
+        }
+
+        // resize to refined size
+        clip_image_u8_ptr refined_img(clip_image_u8_init());
+        if (inst.padding_refined) {
+            image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
+        } else {
+            image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
+        }
+
+        // create slices
+        for (const auto & slice : inst.slices) {
+            int x = slice.x;
+            int y = slice.y;
+            int w = slice.size.width;
+            int h = slice.size.height;
+
+            clip_image_u8_ptr img_slice(clip_image_u8_init());
+            image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
+            output.push_back(std::move(img_slice));
+        }
+
+        return output;
+    }
+
+private:
+    static clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+            float r = static_cast(width) / height;
+            height  = static_cast(scale_resolution / std::sqrt(r));
+            width   = static_cast(height * r);
+        }
+        clip_image_size res;
+        res.width  = ensure_divide(width,  patch_size);
+        res.height = ensure_divide(height, patch_size);
+        return res;
+    }
+
+    static clip_image_size resize_maintain_aspect_ratio(const clip_image_size & orig, const clip_image_size & target_max) {
+        float scale_width  = static_cast(target_max.width)  / orig.width;
+        float scale_height = static_cast(target_max.height) / orig.height;
+        float scale = std::min(scale_width, scale_height);
+        return clip_image_size{
+            static_cast(orig.width  * scale),
+            static_cast(orig.height * scale),
+        };
+    }
+
+    /**
+     * Selects the best resolution from a list of possible resolutions based on the original size.
+     *
+     * For example, when given a list of resolutions:
+     *  - 100x100
+     *  - 200x100
+     *  - 100x200
+     *  - 200x200
+     *
+     * And an input image of size 111x200, then 100x200 is the best fit (least wasted resolution).
+     *
+     * @param original_size The original size of the image
+     * @param possible_resolutions A list of possible resolutions
+     * @return The best fit resolution
+     */
+    static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector & possible_resolutions) {
+        clip_image_size best_fit;
+        int min_wasted_area = std::numeric_limits::max();
+        int max_effective_resolution = 0;
+
+        for (const clip_image_size & candidate : possible_resolutions) {
+            auto target_size = resize_maintain_aspect_ratio(original_size, candidate);
+            int effective_resolution = std::min(
+                target_size.width * target_size.height,
+                original_size.width * original_size.height);
+            int wasted_area = (candidate.width * candidate.height) - effective_resolution;
+
+            if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_area < min_wasted_area)) {
+                max_effective_resolution = effective_resolution;
+                min_wasted_area = wasted_area;
+                best_fit = candidate;
+            }
+
+            LOG_DBG("%s: candidate: %d x %d, target: %d x %d, wasted: %d, effective: %d\n", __func__, candidate.width, candidate.height, target_size.width, target_size.height, wasted_area, effective_resolution);
+        }
+
+        return best_fit;
+    }
+
+    static int ensure_divide(int length, int patch_size) {
+        return std::max(static_cast(std::round(static_cast(length) / patch_size) * patch_size), patch_size);
+    }
+
+    static clip_image_size get_refine_size(const clip_image_size & original_size, const clip_image_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        int grid_x = grid.width;
+        int grid_y = grid.height;
+
+        int refine_width  = ensure_divide(width, grid_x);
+        int refine_height = ensure_divide(height, grid_y);
+
+        clip_image_size grid_size;
+        grid_size.width  = refine_width  / grid_x;
+        grid_size.height = refine_height / grid_y;
+
+        auto best_grid_size  = get_best_resize(grid_size, scale_resolution, patch_size, allow_upscale);
+        int best_grid_width  = best_grid_size.width;
+        int best_grid_height = best_grid_size.height;
+
+        clip_image_size refine_size;
+        refine_size.width  = best_grid_width  * grid_x;
+        refine_size.height = best_grid_height * grid_y;
+        return refine_size;
+    }
+
+    static clip_image_size get_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+        std::vector candidate_split_grids_nums;
+        for (int i : {multiple - 1, multiple, multiple + 1}) {
+            if (i == 1 || i > max_slice_nums) {
+                continue;
+            }
+            candidate_split_grids_nums.push_back(i);
+        }
+
+        std::vector candidate_grids;
+        for (int split_grids_nums : candidate_split_grids_nums) {
+            int m = 1;
+            while (m <= split_grids_nums) {
+                if (split_grids_nums % m == 0) {
+                    candidate_grids.push_back(clip_image_size{m, split_grids_nums / m});
+                }
+                ++m;
+            }
+        }
+
+        clip_image_size best_grid{1, 1};
+        float min_error = std::numeric_limits::infinity();
+        for (const auto& grid : candidate_grids) {
+            float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
+            if (error < min_error) {
+                best_grid = grid;
+                min_error = error;
+            }
+        }
+        return best_grid;
+    }
+};
+
+// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
+// res_imgs memory is being allocated here, previous allocations will be freed if found
+bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
+    clip_image_size original_size{img->nx, img->ny};
+    bool pad_to_square = true;
+    auto & params = ctx->model.hparams;
+    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+    if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
+        pad_to_square = false;
+    }
+
+    if (clip_is_minicpmv(ctx)) {
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+            res_imgs->entries.push_back(std::move(res));
+        }
+
+        res_imgs->grid_x = inst.grid_size.width;
+        res_imgs->grid_y = inst.grid_size.height;
+        return true;
+
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+        clip_image_u8 resized;
+        auto patch_size = params.patch_size * 2;
+        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
+        image_manipulation::bicubic_resize(*img, resized, new_size.width, new_size.height);
+
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        // clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
+        // res_imgs->data[0] = *res;
+        res_imgs->entries.push_back(std::move(img_f32));
+        return true;
+    }
+    else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
+            || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
+            || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
+    ) {
+        clip_image_u8 resized_image;
+        int sz = params.image_size;
+        image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        //clip_image_save_to_bmp(resized_image, "resized.bmp");
+        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(img_f32));
+        return true;
+
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL) {
+        clip_image_u8 resized_image;
+        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
+        image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(img_f32));
+        return true;
+
+    } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
+        GGML_ASSERT(!params.image_res_candidates.empty());
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+            res_imgs->entries.push_back(std::move(res));
+        }
+
+        res_imgs->grid_x = inst.grid_size.width;
+        res_imgs->grid_y = inst.grid_size.height;
+        return true;
+
+    }
+
+    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+
+    clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
+
+    if (pad_to_square) {
+        // for llava-1.5, we resize image to a square, and pad the shorter side with a background color
+        // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+        const int longer_side = std::max(img->nx, img->ny);
+        temp->nx = longer_side;
+        temp->ny = longer_side;
+        temp->buf.resize(3 * longer_side * longer_side);
+
+        // background color in RGB from LLaVA (this is the mean rgb color * 255)
+        const std::array pad_color = {122, 116, 104};
+
+        // resize the image to the target_size
+        image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
+
+        clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std);
+        res_imgs->entries.push_back(std::move(res));
+        return true;
+
+    } else if (!params.image_res_candidates.empty()) {
+        // "spatial_unpad" with "anyres" processing for llava-1.6
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+            res_imgs->entries.push_back(std::move(res));
+        }
+
+        return true;
+
+    }
+
+    GGML_ASSERT(false && "Unknown image preprocessing type");
+}
+
+ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
+    return ctx->model.image_newline;
+}
+
+void clip_free(clip_ctx * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
+    delete ctx;
+}
+
+// deprecated
+size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
+    const int32_t nx = ctx->model.hparams.image_size;
+    const int32_t ny = ctx->model.hparams.image_size;
+    return clip_embd_nbytes_by_img(ctx, nx, ny);
+}
+
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h) {
+    clip_image_f32 img;
+    img.nx = img_w;
+    img.ny = img_h;
+    return clip_n_output_tokens(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
+}
+
+int32_t clip_get_image_size(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.image_size;
+}
+
+int32_t clip_get_patch_size(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.patch_size;
+}
+
+int32_t clip_get_hidden_size(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
+const char * clip_patch_merge_type(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD ? "spatial_unpad" : "flat";
+}
+
+int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->model.hparams;
+    const int n_total = clip_n_output_tokens(ctx, img);
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+        return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
+    }
+    return n_total;
+}
+
+int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->model.hparams;
+    if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL) {
+        return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
+    }
+    return 1;
+}
+
+int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->model.hparams;
+
+    // only for models using fixed size square images
+    int n_patches_sq = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+
+    projector_type proj = ctx->proj_type();
+
+    switch (proj) {
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+            {
+                // do nothing
+            } break;
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+        case PROJECTOR_TYPE_GLM_EDGE:
+            {
+                n_patches_sq /= 4;
+                if (ctx->model.mm_glm_tok_boi) {
+                    n_patches_sq += 2; // for BOI and EOI token embeddings
+                }
+            } break;
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                if (params.minicpmv_version == 2) {
+                    n_patches_sq = 96;
+                } else if (params.minicpmv_version == 3) {
+                    n_patches_sq = 64;
+                } else if (params.minicpmv_version == 4) {
+                    n_patches_sq = 64;
+                } else {
+                    GGML_ABORT("Unknown minicpmv version");
+                }
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                // dynamic size
+                int patch_size = params.patch_size * 2;
+                int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
+                int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+                n_patches_sq = x_patch * y_patch;
+            } break;
+        case PROJECTOR_TYPE_GEMMA3:
+            {
+                int n_per_side = params.image_size / params.patch_size;
+                int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
+                n_patches_sq = n_per_side_2d_pool * n_per_side_2d_pool;
+            } break;
+        case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                // both W and H are divided by proj_scale_factor
+                n_patches_sq /= (params.proj_scale_factor * params.proj_scale_factor);
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                // dynamic size
+                int n_merge = params.spatial_merge_size;
+                int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
+                n_patches_sq = n_patches_y * n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                n_patches_sq /= (scale_factor * scale_factor);
+            } break;
+        case PROJECTOR_TYPE_ULTRAVOX:
+            {
+                const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
+                const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
+                n_patches_sq = n_len / proj_stack_factor / 2;
+            } break;
+        case PROJECTOR_TYPE_QWEN2A:
+            {
+                // divide by 2 because of whisper
+                // another divide by 2 because of nn.AvgPool1d(2, stride=2)
+                n_patches_sq = img->nx / 4;
+            } break;
+        default:
+            GGML_ABORT("unsupported projector type");
+    }
+
+    return n_patches_sq;
+}
+
+static std::vector>> get_1d_sincos_pos_embed_from_grid_new(int embed_dim, const std::vector> & pos) {
+    assert(embed_dim % 2 == 0);
+    int H = pos.size();
+    int W = pos[0].size();
+
+    std::vector omega(embed_dim / 2);
+    for (int i = 0; i < embed_dim / 2; ++i) {
+        omega[i] = 1.0 / pow(10000.0, static_cast(i) / (embed_dim / 2));
+    }
+
+    std::vector>> emb(H, std::vector>(W, std::vector(embed_dim)));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                float out_value = pos[h][w] * omega[d];
+                emb[h][w][d] = sin(out_value);
+                emb[h][w][d + embed_dim / 2] = cos(out_value);
+            }
+        }
+    }
+
+    return emb;
+}
+
+static std::vector>> get_2d_sincos_pos_embed_from_grid(int embed_dim, const std::vector>> & grid) {
+    assert(embed_dim % 2 == 0);
+    std::vector>> emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[0]); // (H, W, D/2)
+    std::vector>> emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, grid[1]); // (H, W, D/2)
+
+    int H = emb_h.size();
+    int W = emb_h[0].size();
+    std::vector>> emb(H, std::vector>(W, std::vector(embed_dim)));
+
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            for (int d = 0; d < embed_dim / 2; ++d) {
+                emb[h][w][d] = emb_h[h][w][d];
+                emb[h][w][d + embed_dim / 2] = emb_w[h][w][d];
+            }
+        }
+    }
+    return emb;
+}
+
+static std::vector> get_2d_sincos_pos_embed(int embed_dim, const std::pair image_size) {
+    int grid_h_size = image_size.first;
+    int grid_w_size = image_size.second;
+
+    std::vector grid_h(grid_h_size);
+    std::vector grid_w(grid_w_size);
+
+    for (int i = 0; i < grid_h_size; ++i) {
+        grid_h[i] = static_cast(i);
+    }
+    for (int i = 0; i < grid_w_size; ++i) {
+        grid_w[i] = static_cast(i);
+    }
+
+    std::vector> grid(grid_h_size, std::vector(grid_w_size));
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid[h][w] = grid_w[w];
+        }
+    }
+    std::vector>> grid_2d = {grid, grid};
+    for (int h = 0; h < grid_h_size; ++h) {
+        for (int w = 0; w < grid_w_size; ++w) {
+            grid_2d[0][h][w] = grid_h[h];
+            grid_2d[1][h][w] = grid_w[w];
+        }
+    }
+
+    std::vector>> pos_embed_3d = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_2d);
+
+    int H = image_size.first;
+    int W = image_size.second;
+    std::vector> pos_embed_2d(H * W, std::vector(embed_dim));
+    for (int h = 0; h < H; ++h) {
+        for (int w = 0; w < W; ++w) {
+            pos_embed_2d[w * H + h] = pos_embed_3d[h][w];
+        }
+    }
+
+    return pos_embed_2d;
+}
+
+bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
+    clip_image_f32_batch imgs;
+    clip_image_f32_ptr img_copy(clip_image_f32_init());
+    *img_copy = *img;
+    imgs.entries.push_back(std::move(img_copy));
+
+    return clip_image_batch_encode(ctx, n_threads, &imgs, vec);
+}
+
+bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
+    const clip_image_f32_batch & imgs = *imgs_c_ptr;
+    int batch_size = imgs.entries.size();
+
+    // TODO @ngxson : implement batch size > 1 as a loop
+    //                we don't need true batching support because the cgraph will gonna be big anyway
+    if (batch_size != 1) {
+        return false; // only support batch size of 1
+    }
+
+    // build the inference graph
+    ctx->debug_print_tensors.clear();
+    ggml_backend_sched_reset(ctx->sched.get());
+    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
+    ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
+
+    // set inputs
+    const auto & model   = ctx->model;
+    const auto & hparams = model.hparams;
+
+    const int image_size_width  = imgs.entries[0]->nx;
+    const int image_size_height = imgs.entries[0]->ny;
+
+    const int patch_size    = hparams.patch_size;
+    const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int n_pos = num_patches + (model.class_embedding ? 1 : 0);
+    const int pos_w = image_size_width  / patch_size;
+    const int pos_h = image_size_height / patch_size;
+
+    const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
+
+    auto get_inp_tensor = [&gf](const char * name) {
+        ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
+        if (inp == nullptr) {
+            GGML_ABORT("Failed to get tensor %s", name);
+        }
+        if (!(inp->flags & GGML_TENSOR_FLAG_INPUT)) {
+            GGML_ABORT("Tensor %s is not an input tensor", name);
+        }
+        return inp;
+    };
+
+    auto set_input_f32 = [&get_inp_tensor](const char * name, std::vector & values) {
+        ggml_tensor * cur = get_inp_tensor(name);
+        GGML_ASSERT(cur->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
+        ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
+    };
+
+    auto set_input_i32 = [&get_inp_tensor](const char * name, std::vector & values) {
+        ggml_tensor * cur = get_inp_tensor(name);
+        GGML_ASSERT(cur->type == GGML_TYPE_I32);
+        GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
+        ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
+    };
+
+    // set input pixel values
+    if (!imgs.is_audio) {
+        size_t nelem = 0;
+        for (const auto & img : imgs.entries) {
+            nelem += img->nx * img->ny * 3;
+        }
+        std::vector inp_raw(nelem);
+
+        // layout of data (note: the channel dim is unrolled to better visualize the layout):
+        //
+        // ┌──W──┐
+        // │     H │  channel = R
+        // ├─────┤ │
+        // │     H │  channel = G
+        // ├─────┤ │
+        // │     H │  channel = B
+        // └─────┘ │
+        //   ──────┘ x B
+
+        for (size_t i = 0; i < imgs.entries.size(); i++) {
+            const int nx = imgs.entries[i]->nx;
+            const int ny = imgs.entries[i]->ny;
+            const int n = nx * ny;
+
+            for (int b = 0; b < batch_size; b++) {
+                float * batch_entry = inp_raw.data() + b * (3*n);
+                for (int y = 0; y < ny; y++) {
+                    for (int x = 0; x < nx; x++) {
+                        size_t base_src = 3*(y * nx + x); // idx of the first channel
+                        size_t base_dst =    y * nx + x;  // idx of the first channel
+                        batch_entry[      base_dst] = imgs.entries[b]->buf[base_src    ];
+                        batch_entry[1*n + base_dst] = imgs.entries[b]->buf[base_src + 1];
+                        batch_entry[2*n + base_dst] = imgs.entries[b]->buf[base_src + 2];
+                    }
+                }
+            }
+        }
+        set_input_f32("inp_raw", inp_raw);
+
+    } else {
+        // audio input
+        GGML_ASSERT(imgs.entries.size() == 1);
+        const auto & mel_inp = imgs.entries[0];
+        const int n_step = mel_inp->nx;
+        const int n_mel  = mel_inp->ny;
+        std::vector inp_raw(n_step * n_mel);
+        std::memcpy(inp_raw.data(), mel_inp->buf.data(), n_step * n_mel * sizeof(float));
+        set_input_f32("inp_raw", inp_raw);
+    }
+
+    // set input per projector
+    switch (ctx->model.proj_type) {
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                // inspired from siglip:
+                //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
+                //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
+                std::vector positions(pos_h * pos_w);
+                int bucket_coords_h[1024];
+                int bucket_coords_w[1024];
+                for (int i = 0; i < pos_h; i++){
+                    bucket_coords_h[i] = std::floor(70.0*i/pos_h);
+                }
+                for (int i = 0; i < pos_w; i++){
+                    bucket_coords_w[i] = std::floor(70.0*i/pos_w);
+                }
+                for (int i = 0, id = 0; i < pos_h; i++){
+                    for (int j = 0; j < pos_w; j++){
+                        positions[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+                    }
+                }
+                set_input_i32("positions", positions);
+
+                // inspired from resampler of Qwen-VL:
+                //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main
+                //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
+                int embed_dim = clip_n_mmproj_embd(ctx);
+
+                // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
+                auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
+
+                std::vector pos_embed(embed_dim * pos_w * pos_h);
+                for(int i = 0; i < pos_w * pos_h; ++i){
+                    for(int j = 0; j < embed_dim; ++j){
+                        pos_embed[i * embed_dim + j] = pos_embed_t[i][j];
+                    }
+                }
+
+                set_input_f32("pos_embed", pos_embed);
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+            {
+                const int merge_ratio = 2;
+                const int pw = image_size_width  / patch_size;
+                const int ph = image_size_height / patch_size;
+                std::vector positions(n_pos * 4);
+                int ptr = 0;
+                for (int y = 0; y < ph; y += merge_ratio) {
+                    for (int x = 0; x < pw; x += merge_ratio) {
+                        for (int dy = 0; dy < 2; dy++) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                positions[                  ptr] = y + dy;
+                                positions[    num_patches + ptr] = x + dx;
+                                positions[2 * num_patches + ptr] = y + dy;
+                                positions[3 * num_patches + ptr] = x + dx;
+                                ptr++;
+                            }
+                        }
+                    }
+                }
+
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                // pw * ph = number of tokens output by ViT after apply patch merger
+                // ipw * ipw = number of vision token been processed inside ViT
+                const int merge_ratio = 2;
+                const int pw  = image_size_width  / patch_size / merge_ratio;
+                const int ph  = image_size_height / patch_size / merge_ratio;
+                const int ipw = image_size_width  / patch_size;
+                const int iph = image_size_height / patch_size;
+
+                std::vector idx    (ph * pw);
+                std::vector inv_idx(ph * pw);
+
+                if (use_window_attn) {
+                    const int attn_window_size = 112;
+                    const int grid_window = attn_window_size / patch_size / merge_ratio;
+                    int dst = 0;
+                    // [num_vision_tokens, num_vision_tokens] attention mask tensor
+                    std::vector mask(pow(ipw * iph, 2), std::numeric_limits::lowest());
+                    int mask_row = 0;
+
+                    for (int y = 0; y < ph; y += grid_window) {
+                        for (int x = 0; x < pw; x += grid_window) {
+                            const int win_h = std::min(grid_window, ph - y);
+                            const int win_w = std::min(grid_window, pw - x);
+                            const int dst_0 = dst;
+                            // group all tokens belong to the same window togather (to a continue range)
+                            for (int dy = 0; dy < win_h; dy++) {
+                                for (int dx = 0; dx < win_w; dx++) {
+                                    const int src = (y + dy) * pw + (x + dx);
+                                    GGML_ASSERT(src < (int)idx.size());
+                                    GGML_ASSERT(dst < (int)inv_idx.size());
+                                    idx    [src] = dst;
+                                    inv_idx[dst] = src;
+                                    dst++;
+                                }
+                            }
+
+                            for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+                                int row_offset = mask_row * (ipw * iph);
+                                std::fill(
+                                    mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+                                    mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio),
+                                    0.0);
+                                mask_row++;
+                            }
+                        }
+                    }
+
+                    set_input_i32("window_idx",     idx);
+                    set_input_i32("inv_window_idx", inv_idx);
+                    set_input_f32("window_mask",    mask);
+                } else {
+                    for (int i = 0; i < ph * pw; i++) {
+                        idx[i] = i;
+                    }
+                }
+
+                const int mpow = merge_ratio * merge_ratio;
+                std::vector positions(n_pos * 4);
+
+                int ptr = 0;
+                for (int y = 0; y < iph; y += merge_ratio) {
+                    for (int x = 0; x < ipw; x += merge_ratio) {
+                        for (int dy = 0; dy < 2; dy++) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                auto remap = idx[ptr / mpow];
+                                remap = (remap * mpow) + (ptr % mpow);
+
+                                positions[                  remap] = y + dy;
+                                positions[    num_patches + remap] = x + dx;
+                                positions[2 * num_patches + remap] = y + dy;
+                                positions[3 * num_patches + remap] = x + dx;
+                                ptr++;
+                            }
+                        }
+                    }
+                }
+
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                // set the 2D positions
+                int n_patches_per_col = image_size_width / patch_size;
+                std::vector pos_data(n_pos);
+                // dimension H
+                for (int i = 0; i < n_pos; i++) {
+                    pos_data[i] = i / n_patches_per_col;
+                }
+                set_input_i32("pos_h", pos_data);
+                // dimension W
+                for (int i = 0; i < n_pos; i++) {
+                    pos_data[i] = i % n_patches_per_col;
+                }
+                set_input_i32("pos_w", pos_data);
+            } break;
+        case PROJECTOR_TYPE_GLM_EDGE:
+        {
+            // llava and other models
+            std::vector positions(n_pos);
+            for (int i = 0; i < n_pos; i++) {
+                positions[i] = i;
+            }
+            set_input_i32("positions", positions);
+        } break;
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+            {
+                // llava and other models
+                std::vector positions(n_pos);
+                for (int i = 0; i < n_pos; i++) {
+                    positions[i] = i;
+                }
+                set_input_i32("positions", positions);
+
+                // The patches vector is used to get rows to index into the embeds with;
+                // we should skip dim 0 only if we have CLS to avoid going out of bounds
+                // when retrieving the rows.
+                int patch_offset = model.class_embedding ? 1 : 0;
+                std::vector patches(num_patches);
+                for (int i = 0; i < num_patches; i++) {
+                    patches[i] = i + patch_offset;
+                }
+                set_input_i32("patches", patches);
+            } break;
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
+        case PROJECTOR_TYPE_QWEN2A:
+        case PROJECTOR_TYPE_ULTRAVOX:
+            {
+                // do nothing
+            } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                // set the 2D positions
+                int n_patches_per_col = image_size_width / patch_size;
+                std::vector pos_data(num_patches + 1, 0); // +1 for the [CLS] token
+                // last pos is always kept 0, it's for CLS
+                // dimension H
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = (i / n_patches_per_col) + 1;
+                }
+                set_input_i32("pos_h", pos_data);
+                // dimension W
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = (i % n_patches_per_col) + 1;
+                }
+                set_input_i32("pos_w", pos_data);
+            } break;
+        default:
+            GGML_ABORT("Unknown projector type");
+    }
+
+    // ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
+    ggml_backend_dev_t dev = ggml_backend_get_device(ctx->backend_cpu);
+    ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+    if (reg) {
+        auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+        if (ggml_backend_set_n_threads_fn) {
+            ggml_backend_set_n_threads_fn(ctx->backend_cpu, n_threads);
+        }
+    }
+
+    auto status = ggml_backend_sched_graph_compute(ctx->sched.get(), gf);
+    if (status != GGML_STATUS_SUCCESS) {
+        LOG_ERR("%s: ggml_backend_sched_graph_compute failed with error %d\n", __func__, status);
+        return false;
+    }
+
+    // print debug nodes
+    if (ctx->debug_graph) {
+        LOG_INF("\n\n---\n\n");
+        LOG_INF("\n\nDebug graph:\n\n");
+        for (ggml_tensor * t : ctx->debug_print_tensors) {
+            std::vector data(ggml_nbytes(t));
+            ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
+            print_tensor_shape(t);
+            print_tensor_data(t, data.data(), 3);
+        }
+    }
+
+    // the last node is the embedding tensor
+    ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+
+    // sanity check (only support batch size of 1 for now)
+    const int n_tokens_out = embeddings->ne[1];
+    const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
+    if (n_tokens_out != expected_n_tokens_out) {
+        LOG_ERR("%s: expected output %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+        GGML_ABORT("Invalid number of output tokens");
+    }
+
+    // copy the embeddings to the location passed by the user
+    ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
+
+    return true;
+}
+
+int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
+    const auto & hparams = ctx->model.hparams;
+    switch (ctx->model.proj_type) {
+        case PROJECTOR_TYPE_LDP:
+            return ctx->model.mm_model_block_1_block_2_1_b->ne[0];
+        case PROJECTOR_TYPE_LDPV2:
+            return ctx->model.mm_model_peg_0_b->ne[0];
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_PIXTRAL:
+            return ctx->model.mm_2_w->ne[1];
+        case PROJECTOR_TYPE_MLP_NORM:
+            return ctx->model.mm_3_b->ne[0];
+        case PROJECTOR_TYPE_MINICPMV:
+            if (hparams.minicpmv_version == 2) {
+                return 4096;
+            } else if (hparams.minicpmv_version == 3) {
+                return 3584;
+            } else if (hparams.minicpmv_version == 4) {
+                return 3584;
+            }
+            GGML_ABORT("Unknown minicpmv version");
+        case PROJECTOR_TYPE_GLM_EDGE:
+            return ctx->model.mm_model_mlp_3_w->ne[1];
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+            return ctx->model.mm_1_b->ne[0];
+        case PROJECTOR_TYPE_GEMMA3:
+            return ctx->model.mm_input_proj_w->ne[0];
+        case PROJECTOR_TYPE_IDEFICS3:
+            return ctx->model.projection->ne[1];
+        case PROJECTOR_TYPE_ULTRAVOX:
+            return ctx->model.mm_2_w->ne[1];
+        case PROJECTOR_TYPE_INTERNVL:
+            return ctx->model.mm_3_w->ne[1];
+        case PROJECTOR_TYPE_LLAMA4:
+            return ctx->model.mm_model_proj->ne[1];
+        case PROJECTOR_TYPE_QWEN2A:
+            return ctx->model.mm_fc_w->ne[1];
+        default:
+            GGML_ABORT("Unknown projector type");
+    }
+}
+
+int clip_is_minicpmv(const struct clip_ctx * ctx) {
+    if (ctx->proj_type() == PROJECTOR_TYPE_MINICPMV) {
+        return ctx->model.hparams.minicpmv_version;
+    }
+    return 0;
+}
+
+bool clip_is_glm(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE;
+}
+
+bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL;
+}
+
+bool clip_is_llava(const struct clip_ctx * ctx) {
+    return ctx->model.hparams.has_llava_projector;
+}
+
+bool clip_is_gemma3(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_GEMMA3;
+}
+
+bool clip_has_vision_encoder(const struct clip_ctx * ctx) {
+    return ctx->model.modality == CLIP_MODALITY_VISION;
+}
+
+bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
+    return ctx->model.modality == CLIP_MODALITY_AUDIO;
+}
+
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
+    return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
+        || ctx->proj_type() == PROJECTOR_TYPE_QWEN2A;
+}
+
+bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
+    clip_image_f32 clip_img;
+    clip_img.buf.resize(h * w * 3);
+    for (int i = 0; i < h*w*3; i++)
+    {
+        clip_img.buf[i] = img[i];
+    }
+    clip_img.nx = w;
+    clip_img.ny = h;
+    clip_image_encode(ctx, n_threads, &clip_img, vec);
+    return true;
+}
+
+//
+// API used internally with mtmd
+//
+
+projector_type clip_get_projector_type(const struct clip_ctx * ctx) {
+    return ctx->proj_type();
+}
+
+void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel) {
+    clip_image_f32 * audio = new clip_image_f32;
+    audio->nx = n_frames;
+    audio->ny = n_mel;
+    audio->buf.resize(n_frames * n_mel);
+    std::memcpy(audio->buf.data(), mel, n_frames * n_mel * sizeof(float));
+
+    batch->entries.push_back(clip_image_f32_ptr(audio));
+    batch->is_audio = true;
+}
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
new file mode 100644
index 000000000..08f3efb7b
--- /dev/null
+++ b/tools/mtmd/clip.h
@@ -0,0 +1,111 @@
+#pragma once
+
+#include "ggml.h"
+#include 
+#include 
+
+// !!! Internal header, to be used by mtmd only !!!
+
+struct clip_ctx;
+
+struct clip_image_size {
+    int width;
+    int height;
+};
+
+struct clip_image_f32;
+struct clip_image_u8_batch;
+struct clip_image_f32_batch;
+
+enum clip_modality {
+    CLIP_MODALITY_VISION,
+    CLIP_MODALITY_AUDIO,
+};
+
+struct clip_context_params {
+    bool use_gpu;
+    enum ggml_log_level verbosity;
+};
+
+struct clip_init_result {
+    struct clip_ctx * ctx_v; // vision context
+    struct clip_ctx * ctx_a; // audio context
+};
+
+struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params);
+
+void clip_free(struct clip_ctx * ctx);
+
+size_t clip_embd_nbytes(const struct clip_ctx * ctx);
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);
+
+int32_t clip_get_image_size (const struct clip_ctx * ctx);
+int32_t clip_get_patch_size (const struct clip_ctx * ctx);
+int32_t clip_get_hidden_size(const struct clip_ctx * ctx);
+
+// TODO: should be enum, not string
+const char * clip_patch_merge_type(const struct clip_ctx * ctx);
+
+int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// for M-RoPE, this will be the number of token positions in X and Y directions
+// for other models, X will be the total number of tokens and Y will be 1
+int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// this should be equal to the embedding dimension of the text model
+int clip_n_mmproj_embd(const struct clip_ctx * ctx);
+
+struct clip_image_size      * clip_image_size_init(void);
+struct clip_image_u8        * clip_image_u8_init (void);
+struct clip_image_f32       * clip_image_f32_init(void);
+struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
+
+// nx, ny are the output image dimensions
+unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
+
+void clip_image_size_free (struct clip_image_size * img_size);
+void clip_image_u8_free (struct clip_image_u8  * img);
+void clip_image_f32_free(struct clip_image_f32 * img);
+void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
+void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+
+// use for accessing underlay data of clip_image_f32_batch
+size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
+size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
+size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
+struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
+
+/**
+ * Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
+ * The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
+ */
+void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
+
+bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
+
+/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
+bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
+
+/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
+bool clip_image_preprocess(struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32_batch * res_imgs );
+
+struct ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx);
+
+bool clip_image_encode      (struct clip_ctx * ctx, int n_threads, struct clip_image_f32 * img, float * vec);
+bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, const struct clip_image_f32_batch * imgs, float * vec);
+
+int clip_is_minicpmv(const struct clip_ctx * ctx);
+bool clip_is_glm(const struct clip_ctx * ctx);
+bool clip_is_qwen2vl(const struct clip_ctx * ctx);
+bool clip_is_llava(const struct clip_ctx * ctx);
+bool clip_is_gemma3(const struct clip_ctx * ctx);
+
+bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);
+
+// use by audio input
+void clip_image_f32_batch_add_mel(struct clip_image_f32_batch * batch, int n_mel, int n_frames, float * mel);
+
+bool clip_has_vision_encoder(const struct clip_ctx * ctx);
+bool clip_has_audio_encoder(const struct clip_ctx * ctx);
+bool clip_has_whisper_encoder(const struct clip_ctx * ctx);
diff --git a/tools/mtmd/deprecation-warning.cpp b/tools/mtmd/deprecation-warning.cpp
new file mode 100644
index 000000000..dded0a56a
--- /dev/null
+++ b/tools/mtmd/deprecation-warning.cpp
@@ -0,0 +1,22 @@
+#include 
+#include 
+
+int main(int argc, char** argv) {
+    std::string filename = "main";
+    if (argc >= 1) {
+        filename = argv[0];
+    }
+
+    // Get only the program name from the full path
+    size_t pos = filename.find_last_of("/\\");
+    if (pos != std::string::npos) {
+        filename = filename.substr(pos+1);
+    }
+
+    fprintf(stdout, "\n");
+    fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
+    fprintf(stdout, "Please use 'llama-mtmd-cli' instead.\n");
+    fprintf(stdout, "\n");
+
+    return EXIT_FAILURE;
+}
diff --git a/examples/llava/convert_image_encoder_to_gguf.py b/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py
similarity index 100%
rename from examples/llava/convert_image_encoder_to_gguf.py
rename to tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py
diff --git a/examples/llava/glmedge-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py
similarity index 100%
rename from examples/llava/glmedge-convert-image-encoder-to-gguf.py
rename to tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py
diff --git a/examples/llava/glmedge-surgery.py b/tools/mtmd/legacy-models/glmedge-surgery.py
similarity index 100%
rename from examples/llava/glmedge-surgery.py
rename to tools/mtmd/legacy-models/glmedge-surgery.py
diff --git a/examples/llava/llava_surgery.py b/tools/mtmd/legacy-models/llava_surgery.py
similarity index 100%
rename from examples/llava/llava_surgery.py
rename to tools/mtmd/legacy-models/llava_surgery.py
diff --git a/examples/llava/llava_surgery_v2.py b/tools/mtmd/legacy-models/llava_surgery_v2.py
similarity index 100%
rename from examples/llava/llava_surgery_v2.py
rename to tools/mtmd/legacy-models/llava_surgery_v2.py
diff --git a/examples/llava/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py
similarity index 100%
rename from examples/llava/minicpmv-convert-image-encoder-to-gguf.py
rename to tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py
diff --git a/examples/llava/minicpmv-surgery.py b/tools/mtmd/legacy-models/minicpmv-surgery.py
similarity index 100%
rename from examples/llava/minicpmv-surgery.py
rename to tools/mtmd/legacy-models/minicpmv-surgery.py
diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp
new file mode 100644
index 000000000..4d053895c
--- /dev/null
+++ b/tools/mtmd/mtmd-audio.cpp
@@ -0,0 +1,769 @@
+#include "mtmd-audio.h"
+
+#define _USE_MATH_DEFINES // for M_PI
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// most of the code here is copied from whisper.cpp
+
+// align x to upper multiple of n
+#define _ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
+
+namespace whisper_preprocessor {
+
+#define SIN_COS_N_COUNT WHISPER_N_FFT
+namespace {
+struct whisper_global_cache {
+    // In FFT, we frequently use sine and cosine operations with the same values.
+    // We can use precalculated values to speed up the process.
+    float sin_vals[SIN_COS_N_COUNT];
+    float cos_vals[SIN_COS_N_COUNT];
+
+    // Hann window (Use cosf to eliminate difference)
+    // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
+    // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
+    float hann_window[WHISPER_N_FFT];
+
+    whisper_global_cache() {
+        fill_sin_cos_table();
+        fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window);
+    }
+
+    void fill_sin_cos_table() {
+        for (int i = 0; i < SIN_COS_N_COUNT; i++) {
+            double theta = (2 * M_PI * i) / SIN_COS_N_COUNT;
+            sin_vals[i] = sinf(theta);
+            cos_vals[i] = cosf(theta);
+        }
+    }
+
+    void fill_hann_window(int length, bool periodic, float * output) {
+        int offset = -1;
+        if (periodic) {
+            offset = 0;
+        }
+        for (int i = 0; i < length; i++) {
+            output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+        }
+    }
+} global_cache;
+}
+
+// naive Discrete Fourier Transform
+// input is real-valued
+// output is complex-valued
+static void dft(const float* in, int N, float* out) {
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+
+    for (int k = 0; k < N; k++) {
+        float re = 0;
+        float im = 0;
+
+        for (int n = 0; n < N; n++) {
+            int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N
+            re += in[n]*global_cache.cos_vals[idx]; // cos(t)
+            im -= in[n]*global_cache.sin_vals[idx]; // sin(t)
+        }
+
+        out[k*2 + 0] = re;
+        out[k*2 + 1] = im;
+    }
+}
+
+// Cooley-Tukey FFT
+// poor man's implementation - use something better
+// input is real-valued
+// output is complex-valued
+static void fft(float* in, int N, float* out) {
+    if (N == 1) {
+        out[0] = in[0];
+        out[1] = 0;
+        return;
+    }
+
+    const int half_N = N / 2;
+    if (N - half_N*2 == 1) {
+        dft(in, N, out);
+        return;
+    }
+
+    float* even = in + N;
+    for (int i = 0; i < half_N; ++i) {
+        even[i]= in[2*i];
+    }
+    float* even_fft = out + 2 * N;
+    fft(even, half_N, even_fft);
+
+    float* odd = even;
+    for (int i = 0; i < half_N; ++i) {
+        odd[i] = in[2*i + 1];
+    }
+    float* odd_fft = even_fft + N;
+    fft(odd, half_N, odd_fft);
+
+    const int sin_cos_step = SIN_COS_N_COUNT / N;
+    for (int k = 0; k < half_N; k++) {
+        int idx = k * sin_cos_step; // t = 2*M_PI*k/N
+        float re = global_cache.cos_vals[idx]; // cos(t)
+        float im = -global_cache.sin_vals[idx]; // sin(t)
+
+        float re_odd = odd_fft[2*k + 0];
+        float im_odd = odd_fft[2*k + 1];
+
+        out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
+        out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
+
+        out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
+        out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
+    }
+}
+
+static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples,
+                                              int n_samples, int frame_size, int frame_step, int n_threads,
+                                              const whisper_filters & filters, whisper_mel & mel) {
+    std::vector fft_in(frame_size * 2, 0.0);
+    std::vector fft_out(frame_size * 2 * 2 * 2);
+
+    int n_fft = filters.n_fft;
+    int i = ith;
+
+    // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
+    WHISPER_ASSERT(n_fft == 1 + (frame_size / 2));
+
+    // calculate FFT only when fft_in are not all zero
+    for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) {
+        const int offset = i * frame_step;
+
+        // apply Hann window (~10% faster)
+        for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) {
+            fft_in[j] = hann[j] * samples[offset + j];
+        }
+
+        // fill the rest with zeros
+        if (n_samples - offset < frame_size) {
+            std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0);
+        }
+
+        // FFT
+        fft(fft_in.data(), frame_size, fft_out.data());
+
+        // Calculate modulus^2 of complex numbers
+        // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
+        for (int j = 0; j < n_fft; j++) {
+            fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]);
+        }
+
+        // mel spectrogram
+        for (int j = 0; j < mel.n_mel; j++) {
+            double sum = 0.0;
+            // unroll loop (suggested by GH user @lunixbochs)
+            int k = 0;
+            for (k = 0; k < n_fft - 3; k += 4) {
+                sum +=
+                        fft_out[k + 0] * filters.data[j * n_fft + k + 0] +
+                        fft_out[k + 1] * filters.data[j * n_fft + k + 1] +
+                        fft_out[k + 2] * filters.data[j * n_fft + k + 2] +
+                        fft_out[k + 3] * filters.data[j * n_fft + k + 3];
+            }
+            // handle n_fft remainder
+            for (; k < n_fft; k++) {
+                sum += fft_out[k] * filters.data[j * n_fft + k];
+            }
+            sum = log10(std::max(sum, 1e-10));
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+
+    // Otherwise fft_out are all zero
+    double sum = log10(1e-10);
+    for (; i < mel.n_len; i += n_threads) {
+        for (int j = 0; j < mel.n_mel; j++) {
+            mel.data[j * mel.n_len + i] = sum;
+        }
+    }
+}
+
+// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
+static bool log_mel_spectrogram(
+        const float * samples,
+        const int   n_samples,
+        const int   /*sample_rate*/,
+        const int   frame_size,
+        const int   frame_step,
+        const int   n_mel,
+        const int   n_threads,
+        const whisper_filters & filters,
+        const bool   debug,
+        whisper_mel & mel) {
+    //const int64_t t_start_us = ggml_time_us();
+
+    // Hann window
+    WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size");
+    const float * hann = global_cache.hann_window;
+
+    // Calculate the length of padding
+    int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30;
+    int64_t stage_2_pad = frame_size / 2;
+
+    // Initialize a vector and copy data from C array to it.
+    std::vector samples_padded;
+    samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2);
+    std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad);
+
+    // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
+    std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0);
+
+    // reflective pad 200 samples at the beginning of audio
+    std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin());
+
+    mel.n_mel     = n_mel;
+    // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
+    // Calculate number of frames + remove the last frame
+    mel.n_len     = (samples_padded.size() - frame_size) / frame_step;
+    // Calculate semi-padded sample length to ensure compatibility
+    mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
+    mel.data.resize(mel.n_mel * mel.n_len);
+
+    {
+        std::vector workers(n_threads - 1);
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw] = std::thread(
+                    log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded),
+                    n_samples + stage_2_pad, frame_size, frame_step, n_threads,
+                    std::cref(filters), std::ref(mel));
+        }
+
+        // main thread
+        log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel);
+
+        for (int iw = 0; iw < n_threads - 1; ++iw) {
+            workers[iw].join();
+        }
+    }
+
+    // clamping and normalization
+    double mmax = -1e20;
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] > mmax) {
+            mmax = mel.data[i];
+        }
+    }
+
+    mmax -= 8.0;
+
+    for (int i = 0; i < mel.n_mel*mel.n_len; i++) {
+        if (mel.data[i] < mmax) {
+            mel.data[i] = mmax;
+        }
+
+        mel.data[i] = (mel.data[i] + 4.0)/4.0;
+    }
+
+    // Dump log_mel_spectrogram
+    if (debug) {
+        std::ofstream outFile("log_mel_spectrogram.json");
+        outFile << "[";
+        for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
+            outFile << mel.data[i] << ", ";
+        }
+        outFile << mel.data[mel.data.size() - 1] << "]";
+        outFile.close();
+    }
+
+    return true;
+}
+
+bool preprocess_audio(
+        const float * samples,
+        size_t n_samples,
+        const whisper_filters & filters,
+        std::vector & output) {
+
+    if (n_samples == 0) {
+        // empty audio
+        return false;
+    }
+
+    whisper_mel out_full;
+    bool ok = log_mel_spectrogram(
+                samples,
+                n_samples,
+                COMMON_SAMPLE_RATE,
+                WHISPER_N_FFT,
+                WHISPER_HOP_LENGTH,
+                filters.n_mel,
+                4, // n_threads
+                filters,
+                false, // debug
+                out_full);
+    if (!ok) {
+        return false;
+    }
+
+    // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel
+    // we always expect the mel to have 3000 silent frames at the end
+    // printf("n_len %d\n", out_full.n_len);
+    const size_t frames_per_chunk = 3000;
+    GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk);
+    for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) {
+        int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off);
+        if ((size_t)n_len < frames_per_chunk) {
+            break; // last uncomplete chunk will always be a padded chunk, safe to ignore
+        }
+
+        whisper_mel out_chunk;
+        out_chunk.n_len     = n_len;
+        out_chunk.n_mel     = out_full.n_mel;
+        out_chunk.n_len_org = out_full.n_mel; // unused
+        out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len);
+
+        for (int i = 0; i < out_full.n_mel; i++) {
+            auto src = out_full.data.begin() + i*out_full.n_len + off;
+            out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk);
+        }
+
+        output.push_back(std::move(out_chunk));
+    }
+
+    return true;
+}
+
+} // namespace whisper_preprocessor
+
+
+// precalculated mel filter banks
+// values are multiplied by 1000.0 to save space, and will be divided by 1000.0 in the end of the function
+//
+// generated from python code:
+//
+// from numpy import load
+// data = load('mel_filters.npz')
+// lst = data.files
+// for item in lst:
+//   print(item)
+//   print(data[item].shape)
+//   n_mel = data[item].shape[0]
+//   n_fft = data[item].shape[1]
+//   for i, row in enumerate(data[item]):
+//     for j, val in enumerate(row):
+//       val = val * 1000.0
+//       if val != 0:
+//         print(f"data[{i*n_fft + j}] = {val:.6f};")
+
+namespace whisper_precalc_filters {
+
+whisper_preprocessor::whisper_filters get_128_bins() {
+    whisper_preprocessor::whisper_filters filters;
+    filters.n_mel = 128;
+    filters.n_fft = 201;
+    std::vector data(filters.n_mel * filters.n_fft, 0.0f);
+
+    data[1] = 12.37398665;
+    data[202] = 30.39256483;
+    data[404] = 24.74797331;
+    data[605] = 18.01857911;
+    data[807] = 37.12195903;
+    data[1008] = 5.64459199;
+    data[1009] = 6.72939420;
+    data[1210] = 36.03715822;
+    data[1412] = 19.10337992;
+    data[1613] = 23.66316877;
+    data[1815] = 31.47736564;
+    data[2016] = 11.28918398;
+    data[2017] = 1.08480197;
+    data[2218] = 41.68175161;
+    data[2420] = 13.45878839;
+    data[2621] = 29.30776216;
+    data[2823] = 25.83277412;
+    data[3024] = 16.93377644;
+    data[3226] = 38.20675984;
+    data[3427] = 4.55979025;
+    data[3428] = 7.81419594;
+    data[3629] = 34.95235741;
+    data[3831] = 20.18818259;
+    data[4032] = 22.57836796;
+    data[4234] = 32.56217018;
+    data[4435] = 10.20438317;
+    data[4436] = 2.16960395;
+    data[4637] = 40.59694707;
+    data[4839] = 14.54358920;
+    data[5040] = 28.22295949;
+    data[5242] = 26.91757679;
+    data[5443] = 15.84897563;
+    data[5645] = 39.29156065;
+    data[5846] = 3.47498828;
+    data[5847] = 8.89899861;
+    data[6048] = 33.86755288;
+    data[6250] = 21.27298526;
+    data[6451] = 21.49356715;
+    data[6653] = 33.64697099;
+    data[6854] = 9.11958050;
+    data[6855] = 3.25440569;
+    data[7056] = 39.51214626;
+    data[7258] = 15.62839188;
+    data[7459] = 27.13815868;
+    data[7661] = 28.00237760;
+    data[7862] = 14.76417296;
+    data[8064] = 40.37636518;
+    data[8265] = 2.38068704;
+    data[8266] = 10.20263787;
+    data[8467] = 31.61146119;
+    data[8669] = 24.54700135;
+    data[8870] = 15.32919332;
+    data[8871] = 1.66583748;
+    data[9072] = 36.72905266;
+    data[9274] = 20.09709924;
+    data[9475] = 16.93102531;
+    data[9476] = 2.90265540;
+    data[9677] = 32.84499049;
+    data[9879] = 23.52004871;
+    data[10080] = 11.03894413;
+    data[10081] = 10.72582975;
+    data[10282] = 22.71829173;
+    data[10484] = 32.27872774;
+    data[10685] = 0.11626833;
+    data[10686] = 22.85348251;
+    data[10887] = 8.56344029;
+    data[10888] = 14.97978810;
+    data[11089] = 15.51398356;
+    data[11090] = 8.51490628;
+    data[11291] = 21.10680379;
+    data[11292] = 3.32652032;
+    data[11493] = 25.47064796;
+    data[11695] = 27.35907957;
+    data[11896] = 0.65853616;
+    data[11897] = 23.83812517;
+    data[12098] = 3.44359246;
+    data[12099] = 21.22455277;
+    data[12300] = 5.35842171;
+    data[12301] = 19.42555793;
+    data[12502] = 6.49324711;
+    data[12503] = 18.35542172;
+    data[12704] = 6.93138083;
+    data[12705] = 17.93504693;
+    data[12906] = 6.74968259;
+    data[12907] = 18.09151843;
+    data[13108] = 6.01899112;
+    data[13109] = 18.75767298;
+    data[13310] = 4.80452832;
+    data[13311] = 19.87172849;
+    data[13512] = 3.16627859;
+    data[13513] = 21.37690969;
+    data[13514] = 1.25317345;
+    data[13714] = 1.15934468;
+    data[13715] = 20.80361731;
+    data[13716] = 4.04486805;
+    data[13917] = 17.55363122;
+    data[13918] = 7.08320038;
+    data[14119] = 14.07538634;
+    data[14120] = 10.32655034;
+    data[14321] = 10.40921453;
+    data[14322] = 13.73696327;
+    data[14523] = 6.59187697;
+    data[14524] = 17.27988198;
+    data[14525] = 1.46804214;
+    data[14725] = 2.65681883;
+    data[14726] = 18.09193194;
+    data[14727] = 5.85655728;
+    data[14928] = 13.34277913;
+    data[14929] = 10.28267574;
+    data[15130] = 8.56800377;
+    data[15131] = 14.72230814;
+    data[15132] = 1.04039861;
+    data[15332] = 3.79085587;
+    data[15333] = 17.14678481;
+    data[15334] = 6.11609267;
+    data[15535] = 11.75929047;
+    data[15536] = 11.13393717;
+    data[15737] = 6.43857848;
+    data[15738] = 16.07806236;
+    data[15739] = 4.23917221;
+    data[15939] = 1.19989377;
+    data[15940] = 12.75671553;
+    data[15941] = 9.65298992;
+    data[16142] = 7.06935255;
+    data[16143] = 14.94054683;
+    data[16144] = 4.19024844;
+    data[16344] = 1.51483389;
+    data[16345] = 12.00899947;
+    data[16346] = 9.84823331;
+    data[16547] = 6.10224018;
+    data[16548] = 15.33857174;
+    data[16549] = 5.57676842;
+    data[16749] = 0.36827257;
+    data[16750] = 9.89749376;
+    data[16751] = 11.35340426;
+    data[16752] = 2.05122307;
+    data[16952] = 3.89297144;
+    data[16953] = 12.97352277;
+    data[16954] = 8.06631614;
+    data[17155] = 6.74493238;
+    data[17156] = 13.85874674;
+    data[17157] = 5.41190524;
+    data[17357] = 0.74220158;
+    data[17358] = 8.98779090;
+    data[17359] = 11.37871388;
+    data[17360] = 3.32958088;
+    data[17560] = 2.82313535;
+    data[17561] = 10.68049297;
+    data[17562] = 9.43340641;
+    data[17563] = 1.76325557;
+    data[17763] = 4.39018616;
+    data[17764] = 11.87758986;
+    data[17765] = 7.97005836;
+    data[17766] = 0.66104700;
+    data[17966] = 5.49466675;
+    data[17967] = 12.62953598;
+    data[17968] = 6.93987962;
+    data[18169] = 6.18401915;
+    data[18170] = 12.93473132;
+    data[18171] = 6.29778765;
+    data[18371] = 0.02325210;
+    data[18372] = 6.50206627;
+    data[18373] = 12.32661773;
+    data[18374] = 6.00216538;
+    data[18574] = 0.31548753;
+    data[18575] = 6.48925547;
+    data[18576] = 12.04130240;
+    data[18577] = 6.01462880;
+    data[18777] = 0.29979556;
+    data[18778] = 6.18288014;
+    data[18779] = 12.04272825;
+    data[18780] = 6.29981188;
+    data[18781] = 0.55689598;
+    data[18980] = 0.01120471;
+    data[18981] = 5.61729167;
+    data[18982] = 11.22337859;
+    data[18983] = 6.82516303;
+    data[18984] = 1.35264499;
+    data[19184] = 4.82410006;
+    data[19185] = 10.16623247;
+    data[19186] = 7.56075513;
+    data[19187] = 2.34590308;
+    data[19387] = 3.83235747;
+    data[19388] = 8.92296247;
+    data[19389] = 8.47910438;
+    data[19390] = 3.50978645;
+    data[19590] = 2.66873185;
+    data[19591] = 7.51965167;
+    data[19592] = 9.55500547;
+    data[19593] = 4.81966138;
+    data[19594] = 0.08431751;
+    data[19793] = 1.35767367;
+    data[19794] = 5.98019501;
+    data[19795] = 10.60271543;
+    data[19796] = 6.25298498;
+    data[19797] = 1.74059917;
+    data[19997] = 4.32644226;
+    data[19998] = 8.73131864;
+    data[19999] = 7.78916525;
+    data[20000] = 3.48923868;
+    data[20200] = 2.57835095;
+    data[20201] = 6.77582854;
+    data[20202] = 9.40941647;
+    data[20203] = 5.31194592;
+    data[20204] = 1.21447595;
+    data[20403] = 0.75411191;
+    data[20404] = 4.75395704;
+    data[20405] = 8.75380263;
+    data[20406] = 7.19209015;
+    data[20407] = 3.28754401;
+    data[20607] = 2.68179690;
+    data[20608] = 6.49331464;
+    data[20609] = 9.11457930;
+    data[20610] = 5.39387390;
+    data[20611] = 1.67316827;
+    data[20810] = 0.57394296;
+    data[20811] = 4.20600036;
+    data[20812] = 7.83805829;
+    data[20813] = 7.52023002;
+    data[20814] = 3.97470826;
+    data[20815] = 0.42918732;
+    data[21014] = 1.90464477;
+    data[21015] = 5.36569161;
+    data[21016] = 8.82673822;
+    data[21017] = 6.27609482;
+    data[21018] = 2.89750961;
+    data[21218] = 2.89885257;
+    data[21219] = 6.19694078;
+    data[21220] = 8.56699049;
+    data[21221] = 5.34748193;
+    data[21222] = 2.12797290;
+    data[21421] = 0.44750227;
+    data[21422] = 3.59030394;
+    data[21423] = 6.73310598;
+    data[21424] = 7.77023612;
+    data[21425] = 4.70231380;
+    data[21426] = 1.63439126;
+    data[21625] = 1.01536023;
+    data[21626] = 4.01018746;
+    data[21627] = 7.00501446;
+    data[21628] = 7.23442994;
+    data[21629] = 4.31095669;
+    data[21630] = 1.38748321;
+    data[21829] = 1.33348850;
+    data[21830] = 4.18730825;
+    data[21831] = 7.04112789;
+    data[21832] = 6.93188375;
+    data[21833] = 4.14605811;
+    data[21834] = 1.36023236;
+    data[22033] = 1.42879714;
+    data[22034] = 4.14824858;
+    data[22035] = 6.86769979;
+    data[22036] = 6.83705276;
+    data[22037] = 4.18239459;
+    data[22038] = 1.52773573;
+    data[22237] = 1.32610439;
+    data[22238] = 3.91751388;
+    data[22239] = 6.50892360;
+    data[22240] = 6.92639686;
+    data[22241] = 4.39672917;
+    data[22242] = 1.86706171;
+    data[22441] = 1.04827771;
+    data[22442] = 3.51767405;
+    data[22443] = 5.98707050;
+    data[22444] = 7.17824046;
+    data[22445] = 4.76767914;
+    data[22446] = 2.35711760;
+    data[22645] = 0.61636406;
+    data[22646] = 2.96949223;
+    data[22647] = 5.32262027;
+    data[22648] = 7.57265091;
+    data[22649] = 5.27558755;
+    data[22650] = 2.97852419;
+    data[22651] = 0.68146095;
+    data[22849] = 0.04971400;
+    data[22850] = 2.29204819;
+    data[22851] = 4.53438237;
+    data[22852] = 6.77671656;
+    data[22853] = 5.90240723;
+    data[22854] = 3.71349836;
+    data[22855] = 1.52458926;
+    data[23054] = 1.50285335;
+    data[23055] = 3.63961048;
+    data[23056] = 5.77636715;
+    data[23057] = 6.63159089;
+    data[23058] = 4.54574358;
+    data[23059] = 2.45989650;
+    data[23060] = 0.37404924;
+    data[23258] = 0.61795861;
+    data[23259] = 2.65410915;
+    data[23260] = 4.69025923;
+    data[23261] = 6.72641024;
+    data[23262] = 5.46034705;
+    data[23263] = 3.47270933;
+    data[23264] = 1.48507138;
+    data[23463] = 1.59233576;
+    data[23464] = 3.53261665;
+    data[23465] = 5.47289755;
+    data[23466] = 6.44368259;
+    data[23467] = 4.54962999;
+    data[23468] = 2.65557761;
+    data[23469] = 0.76152512;
+    data[23667] = 0.46749352;
+    data[23668] = 2.31641904;
+    data[23669] = 4.16534441;
+    data[23670] = 6.01426978;
+    data[23671] = 5.67844696;
+    data[23672] = 3.87357362;
+    data[23673] = 2.06870004;
+    data[23674] = 0.26382666;
+    data[23872] = 1.05349103;
+    data[23873] = 2.81536230;
+    data[23874] = 4.57723346;
+    data[23875] = 6.33910485;
+    data[23876] = 5.12815686;
+    data[23877] = 3.40826320;
+    data[23878] = 1.68837002;
+    data[24077] = 1.43350090;
+    data[24078] = 3.11241671;
+    data[24079] = 4.79133241;
+    data[24080] = 6.40943693;
+    data[24081] = 4.77052201;
+    data[24082] = 3.13160778;
+    data[24083] = 1.49269309;
+    data[24281] = 0.02932359;
+    data[24282] = 1.62918994;
+    data[24283] = 3.22905602;
+    data[24284] = 4.82892245;
+    data[24285] = 6.14671456;
+    data[24286] = 4.58496623;
+    data[24287] = 3.02321767;
+    data[24288] = 1.46146910;
+    data[24486] = 0.13601698;
+    data[24487] = 1.66055572;
+    data[24488] = 3.18509457;
+    data[24489] = 4.70963307;
+    data[24490] = 6.04072399;
+    data[24491] = 4.55250870;
+    data[24492] = 3.06429295;
+    data[24493] = 1.57607743;
+    data[24494] = 0.08786193;
+    data[24691] = 0.09328097;
+    data[24692] = 1.54603878;
+    data[24693] = 2.99879676;
+    data[24694] = 4.45155473;
+    data[24695] = 5.90431225;
+    data[24696] = 4.65566106;
+    data[24697] = 3.23751615;
+    data[24698] = 1.81937125;
+    data[24699] = 0.40122634;
+    data[24897] = 1.30262633;
+    data[24898] = 2.68698297;
+    data[24899] = 4.07133950;
+    data[24900] = 5.45569602;
+    data[24901] = 4.87832492;
+    data[24902] = 3.52695142;
+    data[24903] = 2.17557792;
+    data[24904] = 0.82420459;
+    data[25102] = 0.94595028;
+    data[25103] = 2.26512621;
+    data[25104] = 3.58430226;
+    data[25105] = 4.90347855;
+    data[25106] = 5.20569785;
+    data[25107] = 3.91795207;
+    data[25108] = 2.63020652;
+    data[25109] = 1.34246063;
+    data[25110] = 0.05471494;
+    data[25307] = 0.49037894;
+    data[25308] = 1.74744334;
+    data[25309] = 3.00450763;
+    data[25310] = 4.26157191;
+    data[25311] = 5.51863620;
+    data[25312] = 4.39707236;
+    data[25313] = 3.16995848;
+    data[25314] = 1.94284460;
+    data[25315] = 0.71573065;
+    data[25513] = 1.14698056;
+    data[25514] = 2.34485767;
+    data[25515] = 3.54273478;
+    data[25516] = 4.74061165;
+    data[25517] = 4.95198462;
+    data[25518] = 3.78264743;
+    data[25519] = 2.61331047;
+    data[25520] = 1.44397374;
+    data[25521] = 0.27463681;
+    data[25718] = 0.47569509;
+    data[25719] = 1.61717169;
+    data[25720] = 2.75864848;
+    data[25721] = 3.90012516;
+    data[25722] = 5.04160160;
+    data[25723] = 4.45712078;
+    data[25724] = 3.34284059;
+    data[25725] = 2.22856039;
+    data[25726] = 1.11428020;
+
+    for (auto & val : data) {
+        val /= 1000.0f;
+    }
+
+    filters.data = std::move(data);
+    return filters;
+}
+
+} // namespace whisper_precalc_filters
diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h
new file mode 100644
index 000000000..b7b940aff
--- /dev/null
+++ b/tools/mtmd/mtmd-audio.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include "ggml.h"
+
+#include 
+#include 
+#include 
+
+#define WHISPER_ASSERT GGML_ASSERT
+
+#define WHISPER_SAMPLE_RATE 16000
+#define WHISPER_N_FFT       400
+#define WHISPER_HOP_LENGTH  160
+#define WHISPER_CHUNK_SIZE  30
+
+#define COMMON_SAMPLE_RATE 16000
+
+namespace whisper_preprocessor {
+
+struct whisper_mel {
+    int n_len;
+    int n_len_org;
+    int n_mel;
+
+    std::vector data;
+};
+
+struct whisper_filters {
+    int32_t n_mel;
+    int32_t n_fft;
+
+    std::vector data;
+};
+
+bool preprocess_audio(
+        const float * samples,
+        size_t n_samples,
+        const whisper_filters & filters,
+        std::vector & output);
+
+} // namespace whisper_preprocessor
+
+namespace whisper_precalc_filters {
+
+whisper_preprocessor::whisper_filters get_128_bins();
+
+} // namespace whisper_precalc_filters
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
new file mode 100644
index 000000000..599e682e0
--- /dev/null
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -0,0 +1,386 @@
+#include "arg.h"
+#include "log.h"
+#include "common.h"
+#include "sampling.h"
+#include "llama.h"
+#include "ggml.h"
+#include "console.h"
+#include "chat.h"
+#include "mtmd.h"
+#include "mtmd-helper.h"
+
+#include 
+#include 
+#include 
+
+#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
+#include 
+#include 
+#elif defined (_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include 
+#include 
+#endif
+
+// volatile, because of signal being an interrupt
+static volatile bool g_is_generating = false;
+static volatile bool g_is_interrupted = false;
+
+/**
+ * Please note that this is NOT a production-ready stuff.
+ * It is a playground for trying multimodal support in llama.cpp.
+ * For contributors: please keep this code simple and easy to understand.
+ */
+
+static void show_additional_info(int /*argc*/, char ** argv) {
+    LOG(
+        "Experimental CLI for multimodal\n\n"
+        "Usage: %s [options] -m  --mmproj  --image  --audio